From c71588e18c621b9d632e344bcb0dfac1b4969db3 Mon Sep 17 00:00:00 2001 From: crazycs Date: Mon, 25 Feb 2019 19:00:42 +0800 Subject: [PATCH] executor,sessionctx: Add correctness for more system variables (#7196) (#9420) --- executor/set_test.go | 137 ++++++++++++++++++++++++++ sessionctx/variable/sysvar.go | 23 +++-- sessionctx/variable/varsutil.go | 167 +++++++++++++------------------- 3 files changed, 220 insertions(+), 107 deletions(-) diff --git a/executor/set_test.go b/executor/set_test.go index 0bb16bd837d26..0e6bdb11d4997 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -327,4 +327,141 @@ func (s *testSuite) TestValidateSetVar(c *C) { tk.MustExec("set time_zone='SySTeM'") result = tk.MustQuery("select @@time_zone;") result.Check(testkit.Rows("SYSTEM")) + + // The following cases test value out of range and illegal type when setting system variables. + // See https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html for more details. + tk.MustExec("set @@global.max_connections=100001") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_connections value: '100001'")) + result = tk.MustQuery("select @@global.max_connections;") + result.Check(testkit.Rows("100000")) + + tk.MustExec("set @@global.max_connections=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_connections value: '-1'")) + result = tk.MustQuery("select @@global.max_connections;") + result.Check(testkit.Rows("1")) + + _, err = tk.Exec("set @@global.max_connections='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@global.max_connect_errors=18446744073709551615") + + tk.MustExec("set @@global.max_connect_errors=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_connect_errors value: '-1'")) + result = tk.MustQuery("select @@global.max_connect_errors;") + result.Check(testkit.Rows("1")) + + _, err = tk.Exec("set @@global.max_connect_errors=18446744073709551616") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@global.max_connections=100001") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_connections value: '100001'")) + result = tk.MustQuery("select @@global.max_connections;") + result.Check(testkit.Rows("100000")) + + tk.MustExec("set @@global.max_connections=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_connections value: '-1'")) + result = tk.MustQuery("select @@global.max_connections;") + result.Check(testkit.Rows("1")) + + _, err = tk.Exec("set @@global.max_connections='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@max_sort_length=1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_sort_length value: '1'")) + result = tk.MustQuery("select @@max_sort_length;") + result.Check(testkit.Rows("4")) + + tk.MustExec("set @@max_sort_length=-100") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_sort_length value: '-100'")) + result = tk.MustQuery("select @@max_sort_length;") + result.Check(testkit.Rows("4")) + + tk.MustExec("set @@max_sort_length=8388609") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_sort_length value: '8388609'")) + result = tk.MustQuery("select @@max_sort_length;") + result.Check(testkit.Rows("8388608")) + + _, err = tk.Exec("set @@max_sort_length='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@global.table_definition_cache=399") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect table_definition_cache value: '399'")) + result = tk.MustQuery("select @@global.table_definition_cache;") + result.Check(testkit.Rows("400")) + + tk.MustExec("set @@global.table_definition_cache=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect table_definition_cache value: '-1'")) + result = tk.MustQuery("select @@global.table_definition_cache;") + result.Check(testkit.Rows("400")) + + tk.MustExec("set @@global.table_definition_cache=524289") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect table_definition_cache value: '524289'")) + result = tk.MustQuery("select @@global.table_definition_cache;") + result.Check(testkit.Rows("524288")) + + _, err = tk.Exec("set @@global.table_definition_cache='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@old_passwords=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect old_passwords value: '-1'")) + result = tk.MustQuery("select @@old_passwords;") + result.Check(testkit.Rows("0")) + + tk.MustExec("set @@old_passwords=3") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect old_passwords value: '3'")) + result = tk.MustQuery("select @@old_passwords;") + result.Check(testkit.Rows("2")) + + _, err = tk.Exec("set @@old_passwords='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@tmp_table_size=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect tmp_table_size value: '-1'")) + result = tk.MustQuery("select @@tmp_table_size;") + result.Check(testkit.Rows("1024")) + + tk.MustExec("set @@tmp_table_size=1020") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect tmp_table_size value: '1020'")) + result = tk.MustQuery("select @@tmp_table_size;") + result.Check(testkit.Rows("1024")) + + tk.MustExec("set @@tmp_table_size=167772161") + result = tk.MustQuery("select @@tmp_table_size;") + result.Check(testkit.Rows("167772161")) + + tk.MustExec("set @@tmp_table_size=18446744073709551615") + result = tk.MustQuery("select @@tmp_table_size;") + result.Check(testkit.Rows("18446744073709551615")) + + _, err = tk.Exec("set @@tmp_table_size=18446744073709551616") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + _, err = tk.Exec("set @@tmp_table_size='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@global.connect_timeout=1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect connect_timeout value: '1'")) + result = tk.MustQuery("select @@global.connect_timeout;") + result.Check(testkit.Rows("2")) + + tk.MustExec("set @@global.connect_timeout=31536000") + result = tk.MustQuery("select @@global.connect_timeout;") + result.Check(testkit.Rows("31536000")) + + tk.MustExec("set @@global.connect_timeout=31536001") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect connect_timeout value: '31536001'")) + result = tk.MustQuery("select @@global.connect_timeout;") + result.Check(testkit.Rows("31536000")) + + result = tk.MustQuery("select @@sql_select_limit;") + result.Check(testkit.Rows("18446744073709551615")) + tk.MustExec("set @@sql_select_limit=default") + result = tk.MustQuery("select @@sql_select_limit;") + result.Check(testkit.Rows("18446744073709551615")) + + tk.MustExec("set @@global.flush_time=31536001") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect flush_time value: '31536001'")) + + tk.MustExec("set @@global.interactive_timeout=31536001") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect interactive_timeout value: '31536001'")) } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 5e9b2ef07ac54..708351e93ef31 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/terror" + "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/logutil" ) @@ -127,7 +128,7 @@ var defaultSysVars = []*SysVar{ {ScopeSession, "rand_seed2", ""}, {ScopeGlobal, "validate_password_number_count", "1"}, {ScopeSession, "gtid_next", ""}, - {ScopeGlobal | ScopeSession, "sql_select_limit", "18446744073709551615"}, + {ScopeGlobal | ScopeSession, SQLSelectLimit, "18446744073709551615"}, {ScopeGlobal, "ndb_show_foreign_key_mock_tables", ""}, {ScopeNone, "multi_range_count", "256"}, {ScopeGlobal | ScopeSession, DefaultWeekFormat, "0"}, @@ -135,7 +136,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "slave_transaction_retries", "10"}, {ScopeGlobal | ScopeSession, "default_storage_engine", "InnoDB"}, {ScopeNone, "ft_query_expansion_limit", "20"}, - {ScopeGlobal, "max_connect_errors", "100"}, + {ScopeGlobal, MaxConnectErrors, "100"}, {ScopeGlobal, "sync_binlog", "0"}, {ScopeNone, "max_digest_length", "1024"}, {ScopeNone, "innodb_force_load_corrupted", "OFF"}, @@ -145,7 +146,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "log_backward_compatible_user_definitions", ""}, {ScopeNone, "lc_messages_dir", "/usr/local/mysql-5.6.25-osx10.8-x86_64/share/"}, {ScopeGlobal, "ft_boolean_syntax", "+ -><()~*:\"\"&|"}, - {ScopeGlobal, "table_definition_cache", "1400"}, + {ScopeGlobal, TableDefinitionCache, "-1"}, {ScopeNone, SkipNameResolve, "0"}, {ScopeNone, "performance_schema_max_file_handles", "32768"}, {ScopeSession, "transaction_allow_batching", ""}, @@ -153,7 +154,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "performance_schema_max_statement_classes", "168"}, {ScopeGlobal, "server_id", "0"}, {ScopeGlobal, "innodb_flushing_avg_loops", "30"}, - {ScopeGlobal | ScopeSession, "tmp_table_size", "16777216"}, + {ScopeGlobal | ScopeSession, TmpTableSize, "16777216"}, {ScopeGlobal, "innodb_max_purge_lag", "0"}, {ScopeGlobal | ScopeSession, "preload_buffer_size", "32768"}, {ScopeGlobal, "slave_checkpoint_period", "300"}, @@ -162,8 +163,8 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_flush_log_at_timeout", "1"}, {ScopeGlobal, "innodb_max_undo_log_size", ""}, {ScopeGlobal | ScopeSession, "range_alloc_block_size", "4096"}, - {ScopeGlobal, "connect_timeout", "10"}, - {ScopeGlobal | ScopeSession, "collation_server", "latin1_swedish_ci"}, + {ScopeGlobal, ConnectTimeout, "10"}, + {ScopeGlobal | ScopeSession, "collation_server", charset.CollationUTF8}, {ScopeNone, "have_rtree_keys", "YES"}, {ScopeGlobal, "innodb_old_blocks_pct", "37"}, {ScopeGlobal, "innodb_file_format", "Antelope"}, @@ -742,6 +743,16 @@ const ( ErrorCount = "error_count" // BlockEncryptionMode is the name for 'block_encryption_mode' system variable. BlockEncryptionMode = "block_encryption_mode" + // SQLSelectLimit is the name for 'sql_select_limit' system variable. + SQLSelectLimit = "sql_select_limit" + // MaxConnectErrors is the name for 'max_connect_errors' system variable. + MaxConnectErrors = "max_connect_errors" + // TableDefinitionCache is the name for 'table_definition_cache' system variable. + TableDefinitionCache = "table_definition_cache" + // TmpTableSize is the name for 'tmp_table_size' system variable. + TmpTableSize = "tmp_table_size" + // ConnectTimeout is the name for 'connect_timeout' system variable. + ConnectTimeout = "connect_timeout" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 0c195cbfdb35f..de0b094293a2b 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -16,6 +16,7 @@ package variable import ( "encoding/json" "fmt" + "math" "strconv" "strings" "sync/atomic" @@ -27,6 +28,9 @@ import ( "github.com/pingcap/tidb/types" ) +// secondsPerYear represents seconds in a normal year. Leap year is not considered here. +const secondsPerYear = 60 * 60 * 24 * 365 + // SetDDLReorgWorkerCounter sets ddlReorgWorkerCounter count. // Max worker count is maxDDLReorgWorkerCount. func SetDDLReorgWorkerCounter(cnt int32) { @@ -164,6 +168,46 @@ func ValidateGetSystemVar(name string, isGlobal bool) error { return nil } +func checkUInt64SystemVar(name, value string, min, max uint64, vars *SessionVars) (string, error) { + if value[0] == '-' { + _, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return value, ErrWrongTypeForVar.GenByArgs(name) + } + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return fmt.Sprintf("%d", min), nil + } + val, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return value, ErrWrongTypeForVar.GenByArgs(name) + } + if val < min { + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return fmt.Sprintf("%d", min), nil + } + if val > max { + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return fmt.Sprintf("%d", max), nil + } + return value, nil +} + +func checkInt64SystemVar(name, value string, min, max int64, vars *SessionVars) (string, error) { + val, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return value, ErrWrongTypeForVar.GenByArgs(name) + } + if val < min { + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return fmt.Sprintf("%d", min), nil + } + if val > max { + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return fmt.Sprintf("%d", max), nil + } + return value, nil +} + // ValidateSetSystemVar checks if system variable satisfies specific restriction. func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, error) { if strings.EqualFold(value, "DEFAULT") { @@ -173,19 +217,10 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return value, UnknownSystemVar.GenByArgs(name) } switch name { + case ConnectTimeout: + return checkUInt64SystemVar(name, value, 2, secondsPerYear, vars) case DefaultWeekFormat: - val, err := strconv.Atoi(value) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 0 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "0", nil - } - if val > 7 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "7", nil - } + return checkUInt64SystemVar(name, value, 0, 7, vars) case DelayKeyWrite: if strings.EqualFold(value, "ON") || value == "1" { return "ON", nil @@ -196,101 +231,25 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, } return value, ErrWrongValueForVar.GenByArgs(name, value) case FlushTime: - val, err := strconv.Atoi(value) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 0 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "0", nil - } + return checkUInt64SystemVar(name, value, 0, secondsPerYear, vars) case GroupConcatMaxLen: - val, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 4 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "4", nil - } - if val > 18446744073709551615 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "18446744073709551615", nil - } + // The reasonable range of 'group_concat_max_len' is 4~18446744073709551615(64-bit platforms) + // See https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_group_concat_max_len for details + return checkUInt64SystemVar(name, value, 4, math.MaxUint64, vars) case InteractiveTimeout: - val, err := strconv.Atoi(value) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 1 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "1", nil - } + return checkUInt64SystemVar(name, value, 1, secondsPerYear, vars) case MaxConnections: - val, err := strconv.Atoi(value) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 1 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "1", nil - } - if val > 100000 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "100000", nil - } + return checkUInt64SystemVar(name, value, 1, 100000, vars) + case MaxConnectErrors: + return checkUInt64SystemVar(name, value, 1, math.MaxUint64, vars) case MaxSortLength: - val, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 4 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "4", nil - } - if val > 8388608 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "8388608", nil - } + return checkUInt64SystemVar(name, value, 4, 8388608, vars) case MaxSpRecursionDepth: - val, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 0 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "0", nil - } - if val > 255 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "255", nil - } - case OldPasswords: - val, err := strconv.Atoi(value) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 0 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "0", nil - } - if val > 2 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "2", nil - } + return checkUInt64SystemVar(name, value, 0, 255, vars) case MaxUserConnections: - val, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return value, ErrWrongTypeForVar.GenByArgs(name) - } - if val < 0 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "0", nil - } - if val > 4294967295 { - vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) - return "4294967295", nil - } + return checkUInt64SystemVar(name, value, 0, 4294967295, vars) + case OldPasswords: + return checkUInt64SystemVar(name, value, 0, 2, vars) case SessionTrackGtids: if strings.EqualFold(value, "OFF") || value == "0" { return "OFF", nil @@ -300,6 +259,12 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return "ALL_GTIDS", nil } return value, ErrWrongValueForVar.GenByArgs(name, value) + case SQLSelectLimit: + return checkUInt64SystemVar(name, value, 0, math.MaxUint64, vars) + case TableDefinitionCache: + return checkUInt64SystemVar(name, value, 400, 524288, vars) + case TmpTableSize: + return checkUInt64SystemVar(name, value, 1024, math.MaxUint64, vars) case TimeZone: if strings.EqualFold(value, "SYSTEM") { return "SYSTEM", nil