diff --git a/README.md b/README.md index ec1f533775e24..8d14579d316e6 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ For support, please contact [PingCAP](http://bit.ly/contact_us_via_github). ### To start using TiDB -See [Quick Start Guide](https://pingcap.com/docs/stable/quick-start-with-tidb/). +See [Quick Start Guide](https://docs.pingcap.com/tidb/stable/quick-start-with-tidb). ### To start developing TiDB diff --git a/bindinfo/bind_serial_test.go b/bindinfo/bind_serial_test.go new file mode 100644 index 0000000000000..87fa5c4eb4baa --- /dev/null +++ b/bindinfo/bind_serial_test.go @@ -0,0 +1,931 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package bindinfo_test + +import ( + "context" + "fmt" + "testing" + + "github.com/pingcap/parser" + "github.com/pingcap/parser/auth" + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/bindinfo" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestExplain(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")) + require.True(t, tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin")) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") + + // Add test for SetOprStmt + tk.MustExec("create index index_id on t1(id)") + require.False(t, tk.HasPlan("SELECT * from t1 union SELECT * from t1", "IndexReader")) + require.True(t, tk.HasPlan("SELECT * from t1 use index(index_id) union SELECT * from t1", "IndexReader")) + + tk.MustExec("create global binding for SELECT * from t1 union SELECT * from t1 using SELECT * from t1 use index(index_id) union SELECT * from t1") + + require.True(t, tk.HasPlan("SELECT * from t1 union SELECT * from t1", "IndexReader")) + + tk.MustExec("drop global binding for SELECT * from t1 union SELECT * from t1") +} + +// TestBindingSymbolList tests sql with "?, ?, ?, ?", fixes #13871 +func TestBindingSymbolList(t *testing.T) { + store, dom, clean := 
testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, INDEX ia (a), INDEX ib (b));") + tk.MustExec("insert into t value(1, 1);") + + // before binding + tk.MustQuery("select a, b from t where a = 3 limit 1, 100") + require.Equal(t, "t:ia", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ia(a)")) + + tk.MustExec(`create global binding for select a, b from t where a = 1 limit 0, 1 using select a, b from t use index (ib) where a = 1 limit 0, 1`) + + // after binding + tk.MustQuery("select a, b from t where a = 3 limit 1, 100") + require.Equal(t, "t:ib", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ib(b)")) + + // Normalize + sql, hash := parser.NormalizeDigest("select a, b from test . t where a = 1 limit 0, 1") + + bindData := dom.BindHandle().GetBindRecord(hash.String(), sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select `a` , `b` from `test` . `t` where `a` = ? 
limit ...", bindData.OriginalSQL) + bind := bindData.Bindings[0] + require.Equal(t, "SELECT `a`,`b` FROM `test`.`t` USE INDEX (`ib`) WHERE `a` = 1 LIMIT 0,1", bind.BindSQL) + require.Equal(t, "test", bindData.Db) + require.Equal(t, "using", bind.Status) + require.NotNil(t, bind.Charset) + require.NotNil(t, bind.Collation) + require.NotNil(t, bind.CreateTime) + require.NotNil(t, bind.UpdateTime) +} + +func TestDMLSQLBind(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1(a int, b int, c int, key idx_b(b), key idx_c(c))") + tk.MustExec("create table t2(a int, b int, c int, key idx_b(b), key idx_c(c))") + + tk.MustExec("delete from t1 where b = 1 and c > 1") + require.Equal(t, "t1:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("delete from t1 where b = 1 and c > 1", "idx_b(b)")) + tk.MustExec("create global binding for delete from t1 where b = 1 and c > 1 using delete /*+ use_index(t1,idx_c) */ from t1 where b = 1 and c > 1") + tk.MustExec("delete from t1 where b = 1 and c > 1") + require.Equal(t, "t1:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("delete from t1 where b = 1 and c > 1", "idx_c(c)")) + + require.True(t, tk.HasPlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "HashJoin")) + tk.MustExec("create global binding for delete t1, t2 from t1 inner join t2 on t1.b = t2.b using delete /*+ inl_join(t1) */ t1, t2 from t1 inner join t2 on t1.b = t2.b") + require.True(t, tk.HasPlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "IndexJoin")) + + tk.MustExec("update t1 set a = 1 where b = 1 and c > 1") + require.Equal(t, "t1:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("update t1 set a = 1 where b = 1 and c > 1", "idx_b(b)")) + 
tk.MustExec("create global binding for update t1 set a = 1 where b = 1 and c > 1 using update /*+ use_index(t1,idx_c) */ t1 set a = 1 where b = 1 and c > 1") + tk.MustExec("delete from t1 where b = 1 and c > 1") + require.Equal(t, "t1:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("update t1 set a = 1 where b = 1 and c > 1", "idx_c(c)")) + + require.True(t, tk.HasPlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "HashJoin")) + tk.MustExec("create global binding for update t1, t2 set t1.a = 1 where t1.b = t2.b using update /*+ inl_join(t1) */ t1, t2 set t1.a = 1 where t1.b = t2.b") + require.True(t, tk.HasPlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "IndexJoin")) + + tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2") + require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_b(b)")) + tk.MustExec("create global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1 using insert /*+ use_index(t2,idx_c) */ into t1 select * from t2 where t2.b = 1 and t2.c > 1") + tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2") + require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_b(b)")) + tk.MustExec("drop global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1") + tk.MustExec("create global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1 using insert into t1 select /*+ use_index(t2,idx_c) */ * from t2 where t2.b = 1 and t2.c > 1") + tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2") + require.Equal(t, "t2:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("insert into t1 select * from t2 
where t2.b = 2 and t2.c > 2", "idx_c(c)")) + + tk.MustExec("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2") + require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_b(b)")) + tk.MustExec("create global binding for replace into t1 select * from t2 where t2.b = 1 and t2.c > 1 using replace into t1 select /*+ use_index(t2,idx_c) */ * from t2 where t2.b = 1 and t2.c > 1") + tk.MustExec("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2") + require.Equal(t, "t2:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_c(c)")) +} + +func TestBestPlanInBaselines(t *testing.T) { + store, dom, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, INDEX ia (a), INDEX ib (b));") + tk.MustExec("insert into t value(1, 1);") + + // before binding + tk.MustQuery("select a, b from t where a = 3 limit 1, 100") + require.Equal(t, "t:ia", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ia(a)")) + + tk.MustQuery("select a, b from t where b = 3 limit 1, 100") + require.Equal(t, "t:ib", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select a, b from t where b = 3 limit 1, 100", "ib(b)")) + + tk.MustExec(`create global binding for select a, b from t where a = 1 limit 0, 1 using select /*+ use_index(@sel_1 test.t ia) */ a, b from t where a = 1 limit 0, 1`) + tk.MustExec(`create global binding for select a, b from t where b = 1 limit 0, 1 using select /*+ use_index(@sel_1 test.t ib) */ a, b from t where b = 1 limit 0, 1`) + + sql, hash := 
utilNormalizeWithDefaultDB(t, "select a, b from t where a = 1 limit 0, 1", "test") + bindData := dom.BindHandle().GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select `a` , `b` from `test` . `t` where `a` = ? limit ...", bindData.OriginalSQL) + bind := bindData.Bindings[0] + require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` `ia`)*/ `a`,`b` FROM `test`.`t` WHERE `a` = 1 LIMIT 0,1", bind.BindSQL) + require.Equal(t, "test", bindData.Db) + require.Equal(t, "using", bind.Status) + + tk.MustQuery("select a, b from t where a = 3 limit 1, 10") + require.Equal(t, "t:ia", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ia(a)")) + + tk.MustQuery("select a, b from t where b = 3 limit 1, 100") + require.Equal(t, "t:ib", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select a, b from t where b = 3 limit 1, 100", "ib(b)")) +} + +func TestErrorBind(t *testing.T) { + store, dom, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustGetErrMsg("create global binding for select * from t using select * from t", "[schema:1146]Table 'test.t' doesn't exist") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(i int, s varchar(20))") + tk.MustExec("create table t1(i int, s varchar(20))") + tk.MustExec("create index index_t on t(i,s)") + + _, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100") + require.NoError(t, err, "err %v", err) + + sql, hash := parser.NormalizeDigest("select * from test . t where i > ?") + bindData := dom.BindHandle().GetBindRecord(hash.String(), sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . 
`t` where `i` > ?", bindData.OriginalSQL) + bind := bindData.Bindings[0] + require.Equal(t, "SELECT * FROM `test`.`t` USE INDEX (`index_t`) WHERE `i` > 100", bind.BindSQL) + require.Equal(t, "test", bindData.Db) + require.Equal(t, "using", bind.Status) + require.NotNil(t, bind.Charset) + require.NotNil(t, bind.Collation) + require.NotNil(t, bind.CreateTime) + require.NotNil(t, bind.UpdateTime) + + tk.MustExec("drop index index_t on t") + _, err = tk.Exec("select * from t where i > 10") + require.NoError(t, err) + + dom.BindHandle().DropInvalidBindRecord() + + rs, err := tk.Exec("show global bindings") + require.NoError(t, err) + chk := rs.NewChunk() + err = rs.Next(context.TODO(), chk) + require.NoError(t, err) + require.Equal(t, 0, chk.NumRows()) +} + +func TestDMLEvolveBaselines(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = true + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, index idx_b(b), index idx_c(c))") + tk.MustExec("insert into t values (1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5)") + tk.MustExec("analyze table t") + tk.MustExec("set @@tidb_evolve_plan_baselines=1") + + tk.MustExec("create global binding for delete from t where b = 1 and c > 1 using delete /*+ use_index(t,idx_c) */ from t where b = 1 and c > 1") + rows := tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 1) + tk.MustExec("delete /*+ use_index(t,idx_b) */ from t where b = 2 and c > 1") + require.Equal(t, "t:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tk.MustExec("admin flush bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 1) + tk.MustExec("admin evolve bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 
1) + + tk.MustExec("create global binding for update t set a = 1 where b = 1 and c > 1 using update /*+ use_index(t,idx_c) */ t set a = 1 where b = 1 and c > 1") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + tk.MustExec("update /*+ use_index(t,idx_b) */ t set a = 2 where b = 2 and c > 1") + require.Equal(t, "t:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tk.MustExec("admin flush bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + tk.MustExec("admin evolve bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + + tk.MustExec("create table t1 like t") + tk.MustExec("create global binding for insert into t1 select * from t where t.b = 1 and t.c > 1 using insert into t1 select /*+ use_index(t,idx_c) */ * from t where t.b = 1 and t.c > 1") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 3) + tk.MustExec("insert into t1 select /*+ use_index(t,idx_b) */ * from t where t.b = 2 and t.c > 2") + require.Equal(t, "t:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tk.MustExec("admin flush bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 3) + tk.MustExec("admin evolve bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 3) + + tk.MustExec("create global binding for replace into t1 select * from t where t.b = 1 and t.c > 1 using replace into t1 select /*+ use_index(t,idx_c) */ * from t where t.b = 1 and t.c > 1") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 4) + tk.MustExec("replace into t1 select /*+ use_index(t,idx_b) */ * from t where t.b = 2 and t.c > 2") + require.Equal(t, "t:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tk.MustExec("admin flush bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 4) + tk.MustExec("admin evolve bindings") + rows = 
tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 4) +} + +func TestAddEvolveTasks(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = true + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))") + tk.MustExec("insert into t values (1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5)") + tk.MustExec("analyze table t") + tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 and c = 0 using select * from t use index(idx_a) where a >= 1 and b >= 1 and c = 0") + tk.MustExec("set @@tidb_evolve_plan_baselines=1") + // It cannot choose table path although it has lowest cost. + tk.MustQuery("select * from t where a >= 4 and b >= 1 and c = 0") + require.Equal(t, "t:idx_a", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tk.MustExec("admin flush bindings") + rows := tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0", rows[0][1]) + require.Equal(t, "pending verify", rows[0][3]) + tk.MustExec("admin evolve bindings") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0", rows[0][1]) + status := rows[0][3].(string) + require.True(t, status == "using" || status == "rejected") +} + +func TestRuntimeHintsInEvolveTasks(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = true + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, clean := testkit.CreateMockStore(t) + defer 
clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("set @@tidb_evolve_plan_baselines=1") + tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))") + + tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 and c = 0 using select * from t use index(idx_a) where a >= 1 and b >= 1 and c = 0") + tk.MustQuery("select /*+ MAX_EXECUTION_TIME(5000) */ * from t where a >= 4 and b >= 1 and c = 0") + tk.MustExec("admin flush bindings") + rows := tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` `idx_c`), max_execution_time(5000)*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0", rows[0][1]) +} + +func TestDefaultSessionVars(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustQuery(`show variables like "%baselines%"`).Sort().Check(testkit.Rows( + "tidb_capture_plan_baselines OFF", + "tidb_evolve_plan_baselines OFF", + "tidb_use_plan_baselines ON")) + tk.MustQuery(`show global variables like "%baselines%"`).Sort().Check(testkit.Rows( + "tidb_capture_plan_baselines OFF", + "tidb_evolve_plan_baselines OFF", + "tidb_use_plan_baselines ON")) +} + +func TestCaptureBaselinesScope(t *testing.T) { + store, dom, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + + utilCleanBindingEnv(tk1, dom) + tk1.MustQuery(`show session variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( + "tidb_capture_plan_baselines OFF", + )) + tk1.MustQuery(`show global variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( + "tidb_capture_plan_baselines OFF", + )) + tk1.MustQuery(`select @@session.tidb_capture_plan_baselines`).Check(testkit.Rows( + "0", + )) + tk1.MustQuery(`select 
@@global.tidb_capture_plan_baselines`).Check(testkit.Rows( + "0", + )) + + tk1.MustExec("set @@session.tidb_capture_plan_baselines = on") + defer func() { + tk1.MustExec(" set @@session.tidb_capture_plan_baselines = off") + }() + tk1.MustQuery(`show session variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( + "tidb_capture_plan_baselines ON", + )) + tk1.MustQuery(`show global variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( + "tidb_capture_plan_baselines OFF", + )) + tk1.MustQuery(`select @@session.tidb_capture_plan_baselines`).Check(testkit.Rows( + "1", + )) + tk1.MustQuery(`select @@global.tidb_capture_plan_baselines`).Check(testkit.Rows( + "0", + )) + tk2.MustQuery(`show session variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( + "tidb_capture_plan_baselines ON", + )) + tk2.MustQuery(`show global variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( + "tidb_capture_plan_baselines OFF", + )) + tk2.MustQuery(`select @@session.tidb_capture_plan_baselines`).Check(testkit.Rows( + "1", + )) + tk2.MustQuery(`select @@global.tidb_capture_plan_baselines`).Check(testkit.Rows( + "0", + )) +} + +func TestStmtHints(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, index idx(a))") + tk.MustExec("create global binding for select * from t using select /*+ MAX_EXECUTION_TIME(100), MEMORY_QUOTA(1 GB) */ * from t use index(idx)") + tk.MustQuery("select * from t") + require.Equal(t, int64(1073741824), tk.Session().GetSessionVars().StmtCtx.MemQuotaQuery) + require.Equal(t, uint64(100), tk.Session().GetSessionVars().StmtCtx.MaxExecutionTime) + tk.MustQuery("select a, b from t") + require.Equal(t, int64(0), tk.Session().GetSessionVars().StmtCtx.MemQuotaQuery) + require.Equal(t, uint64(0), tk.Session().GetSessionVars().StmtCtx.MaxExecutionTime) +} + 
+func TestPrivileges(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, index idx(a))") + tk.MustExec("create global binding for select * from t using select * from t use index(idx)") + require.True(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) + rows := tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 1) + tk.MustExec("create user test@'%'") + require.True(t, tk.Session().Auth(&auth.UserIdentity{Username: "test", Hostname: "%"}, nil, nil)) + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 0) +} + +func TestHintsSetEvolveTask(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = true + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, dom, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, index idx_a(a))") + tk.MustExec("create global binding for select * from t where a > 10 using select * from t ignore index(idx_a) where a > 10") + tk.MustExec("set @@tidb_evolve_plan_baselines=1") + tk.MustQuery("select * from t use index(idx_a) where a > 0") + bindHandle := dom.BindHandle() + bindHandle.SaveEvolveTasksToStore() + // Verify the added Binding for evolution contains valid ID and Hint, otherwise, panic may happen. + sql, hash := utilNormalizeWithDefaultDB(t, "select * from t where a > ?", "test") + bindData := bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . 
`t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 2) + bind := bindData.Bindings[1] + require.Equal(t, bindinfo.PendingVerify, bind.Status) + require.NotEqual(t, "", bind.ID) + require.NotNil(t, bind.Hint) +} + +func TestHintsSetID(t *testing.T) { + store, dom, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, index idx_a(a))") + tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(test.t, idx_a) */ * from t where a > 10") + bindHandle := dom.BindHandle() + // Verify the added Binding contains ID with restored query block. + sql, hash := utilNormalizeWithDefaultDB(t, "select * from t where a > ?", "test") + bindData := bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 1) + bind := bindData.Bindings[0] + require.Equal(t, "use_index(@`sel_1` `test`.`t` `idx_a`)", bind.ID) + + utilCleanBindingEnv(tk, dom) + tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(t, idx_a) */ * from t where a > 10") + bindData = bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 1) + bind = bindData.Bindings[0] + require.Equal(t, "use_index(@`sel_1` `test`.`t` `idx_a`)", bind.ID) + + utilCleanBindingEnv(tk, dom) + tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(@sel_1 t, idx_a) */ * from t where a > 10") + bindData = bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . 
`t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 1) + bind = bindData.Bindings[0] + require.Equal(t, "use_index(@`sel_1` `test`.`t` `idx_a`)", bind.ID) + + utilCleanBindingEnv(tk, dom) + tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(@qb1 t, idx_a) qb_name(qb1) */ * from t where a > 10") + bindData = bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 1) + bind = bindData.Bindings[0] + require.Equal(t, "use_index(@`sel_1` `test`.`t` `idx_a`)", bind.ID) + + utilCleanBindingEnv(tk, dom) + tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(T, IDX_A) */ * from t where a > 10") + bindData = bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . `t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 1) + bind = bindData.Bindings[0] + require.Equal(t, "use_index(@`sel_1` `test`.`t` `idx_a`)", bind.ID) + + utilCleanBindingEnv(tk, dom) + err := tk.ExecToErr("create global binding for select * from t using select /*+ non_exist_hint() */ * from t") + require.True(t, terror.ErrorEqual(err, parser.ErrWarnOptimizerHintParseError)) + tk.MustExec("create global binding for select * from t where a > 10 using select * from t where a > 10") + bindData = bindHandle.GetBindRecord(hash, sql, "test") + require.NotNil(t, bindData) + require.Equal(t, "select * from `test` . 
`t` where `a` > ?", bindData.OriginalSQL) + require.Len(t, bindData.Bindings, 1) + bind = bindData.Bindings[0] + require.Equal(t, "", bind.ID) +} + +func TestNotEvolvePlanForReadStorageHint(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = true + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, index idx_a(a), index idx_b(b))") + tk.MustExec("insert into t values (1,1), (2,2), (3,3), (4,4), (5,5), (6,6), (7,7), (8,8), (9,9), (10,10)") + tk.MustExec("analyze table t") + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + require.True(t, exists) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + // Make sure the best plan of the SQL is use TiKV index. + tk.MustExec("set @@session.tidb_executor_concurrency = 4;") + rows := tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() + require.Equal(t, "cop[tikv]", fmt.Sprintf("%v", rows[len(rows)-1][2])) + + tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select /*+ read_from_storage(tiflash[t]) */ * from t where a >= 1 and b >= 1") + tk.MustExec("set @@tidb_evolve_plan_baselines=1") + + // Even if index of TiKV has lower cost, it chooses TiFlash. + rows = tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() + require.Equal(t, "cop[tiflash]", fmt.Sprintf("%v", rows[len(rows)-1][2])) + + tk.MustExec("admin flush bindings") + rows = tk.MustQuery("show global bindings").Rows() + // None evolve task, because of the origin binding is a read_from_storage binding. 
+ require.Len(t, rows, 1) + require.Equal(t, "SELECT /*+ read_from_storage(tiflash[`t`])*/ * FROM `test`.`t` WHERE `a` >= 1 AND `b` >= 1", rows[0][1]) + require.Equal(t, "using", rows[0][3]) +} + +func TestBindingWithIsolationRead(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, index idx_a(a), index idx_b(b))") + tk.MustExec("insert into t values (1,1), (2,2), (3,3), (4,4), (5,5), (6,6), (7,7), (8,8), (9,9), (10,10)") + tk.MustExec("analyze table t") + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + require.True(t, exists) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select * from t use index(idx_a) where a >= 1 and b >= 1") + tk.MustExec("set @@tidb_use_plan_baselines = 1") + rows := tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() + require.Equal(t, "cop[tikv]", rows[len(rows)-1][2]) + // Even if we build a binding use index for SQL, but after we set the isolation read for TiFlash, it choose TiFlash instead of index of TiKV. 
+ tk.MustExec("set @@tidb_isolation_read_engines = \"tiflash\"") + rows = tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() + require.Equal(t, "cop[tiflash]", rows[len(rows)-1][2]) +} + +func TestReCreateBindAfterEvolvePlan(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = true + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))") + tk.MustExec("insert into t values (1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5)") + tk.MustExec("analyze table t") + tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select * from t use index(idx_a) where a >= 1 and b >= 1") + tk.MustExec("set @@tidb_evolve_plan_baselines=1") + + // It cannot choose table path although it has lowest cost. 
+ tk.MustQuery("select * from t where a >= 0 and b >= 0") + require.Equal(t, "t:idx_a", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + + tk.MustExec("admin flush bindings") + rows := tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 2) + require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t` WHERE `a` >= 0 AND `b` >= 0", rows[0][1]) + require.Equal(t, "pending verify", rows[0][3]) + + tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select * from t use index(idx_b) where a >= 1 and b >= 1") + rows = tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 1) + tk.MustQuery("select * from t where a >= 4 and b >= 1") + require.Equal(t, "t:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) +} + +func TestInvisibleIndex(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, unique idx_a(a), index idx_b(b) invisible)") + tk.MustGetErrMsg( + "create global binding for select * from t using select * from t use index(idx_b) ", + "[planner:1176]Key 'idx_b' doesn't exist in table 't'") + + // Create bind using index + tk.MustExec("create global binding for select * from t using select * from t use index(idx_a) ") + + tk.MustQuery("select * from t") + require.Equal(t, "t:idx_a", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("select * from t", "idx_a(a)")) + + tk.MustExec(`prepare stmt1 from 'select * from t'`) + tk.MustExec("execute stmt1") + require.Len(t, tk.Session().GetSessionVars().StmtCtx.IndexNames, 1) + require.Equal(t, "t:idx_a", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + + // And then make this index invisible + tk.MustExec("alter table t alter index idx_a invisible") + tk.MustQuery("select * from t") + require.Len(t, 
tk.Session().GetSessionVars().StmtCtx.IndexNames, 0) + + tk.MustExec("execute stmt1") + require.Len(t, tk.Session().GetSessionVars().StmtCtx.IndexNames, 0) + + tk.MustExec("drop binding for select * from t") +} + +func TestSPMHitInfo(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")) + require.True(t, tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin")) + + tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") + tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0")) + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")) + tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") + tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1")) + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + +func TestReCreateBind(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, index idx(a))") + + tk.MustQuery("select * from mysql.bind_info where source != 'builtin'").Check(testkit.Rows()) + tk.MustQuery("show global bindings").Check(testkit.Rows()) + + tk.MustExec("create global binding for select * from t using select * from t") + tk.MustQuery("select original_sql, status from mysql.bind_info where source != 'builtin';").Check(testkit.Rows( + "select * from `test` . 
`t` using",
+	))
+	rows := tk.MustQuery("show global bindings").Rows()
+	require.Len(t, rows, 1)
+	require.Equal(t, "select * from `test` . `t`", rows[0][0])
+	require.Equal(t, "using", rows[0][3])
+
+	tk.MustExec("create global binding for select * from t using select * from t")
+	rows = tk.MustQuery("show global bindings").Rows()
+	require.Len(t, rows, 1)
+	require.Equal(t, "select * from `test` . `t`", rows[0][0])
+	require.Equal(t, "using", rows[0][3])
+
+	rows = tk.MustQuery("select original_sql, status from mysql.bind_info where source != 'builtin';").Rows()
+	require.Len(t, rows, 2)
+	require.Equal(t, "deleted", rows[0][1])
+	require.Equal(t, "using", rows[1][1])
+}
+
+func TestExplainShowBindSQL(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t(a int, b int, key(a))")
+
+	tk.MustExec("create global binding for select * from t using select * from t use index(a)")
+	tk.MustQuery("select original_sql, bind_sql from mysql.bind_info where default_db != 'mysql'").Check(testkit.Rows(
+		"select * from `test` . `t` SELECT * FROM `test`.`t` USE INDEX (`a`)",
+	))
+
+	tk.MustExec("explain format = 'verbose' select * from t")
+	tk.MustQuery("show warnings").Check(testkit.Rows("Note 1105 Using the bindSQL: SELECT * FROM `test`.`t` USE INDEX (`a`)"))
+	// explain analyze does not support the verbose format yet.
+} + +func TestDMLIndexHintBind(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, c int, key idx_b(b), key idx_c(c))") + + tk.MustExec("delete from t where b = 1 and c > 1") + require.Equal(t, "t:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("delete from t where b = 1 and c > 1", "idx_b(b)")) + tk.MustExec("create global binding for delete from t where b = 1 and c > 1 using delete from t use index(idx_c) where b = 1 and c > 1") + tk.MustExec("delete from t where b = 1 and c > 1") + require.Equal(t, "t:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + require.True(t, tk.MustUseIndex("delete from t where b = 1 and c > 1", "idx_c(c)")) +} + +func TestForbidEvolvePlanBaseLinesBeforeGA(t *testing.T) { + originalVal := config.CheckTableBeforeDrop + config.CheckTableBeforeDrop = false + defer func() { + config.CheckTableBeforeDrop = originalVal + }() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + err := tk.ExecToErr("set @@tidb_evolve_plan_baselines=0") + require.Equal(t, nil, err) + err = tk.ExecToErr("set @@TiDB_Evolve_pLan_baselines=1") + require.Regexp(t, "Cannot enable baseline evolution feature, it is not generally available now", err) + err = tk.ExecToErr("set @@TiDB_Evolve_pLan_baselines=oN") + require.Regexp(t, "Cannot enable baseline evolution feature, it is not generally available now", err) + err = tk.ExecToErr("admin evolve bindings") + require.Regexp(t, "Cannot enable baseline evolution feature, it is not generally available now", err) +} + +func TestExplainTableStmts(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int, value decimal(5,2))") + 
tk.MustExec("table t") + tk.MustExec("explain table t") + tk.MustExec("desc table t") +} + +func TestSPMWithoutUseDatabase(t *testing.T) { + store, dom, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk1 := testkit.NewTestKit(t, store) + utilCleanBindingEnv(tk, dom) + utilCleanBindingEnv(tk1, dom) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, key(a))") + tk.MustExec("create global binding for select * from t using select * from t force index(a)") + + err := tk1.ExecToErr("select * from t") + require.Regexp(t, ".*No database selected", err) + tk1.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0")) + require.True(t, tk1.MustUseIndex("select * from test.t", "a")) + tk1.MustExec("select * from test.t") + tk1.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1")) +} + +func TestBindingWithoutCharset(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a varchar(10) CHARACTER SET utf8)") + tk.MustExec("create global binding for select * from t where a = 'aa' using select * from t where a = 'aa'") + rows := tk.MustQuery("show global bindings").Rows() + require.Len(t, rows, 1) + require.Equal(t, "select * from `test` . 
`t` where `a` = ?", rows[0][0])
+	require.Equal(t, "SELECT * FROM `test`.`t` WHERE `a` = 'aa'", rows[0][1])
+}
+
+func TestGCBindRecord(t *testing.T) {
+	store, dom, clean := testkit.CreateMockStoreAndDomain(t)
+	defer clean()
+
+	// set lease for gc tests
+	originLease := bindinfo.Lease
+	bindinfo.Lease = 0
+
+	defer func() {
+		bindinfo.Lease = originLease
+	}()
+
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t(a int, b int, key(a))")
+
+	tk.MustExec("create global binding for select * from t where a = 1 using select * from t use index(a) where a = 1")
+	rows := tk.MustQuery("show global bindings").Rows()
+	require.Len(t, rows, 1)
+	require.Equal(t, "select * from `test` . `t` where `a` = ?", rows[0][0])
+	require.Equal(t, "using", rows[0][3])
+	tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . `t` where `a` = ?'").Check(testkit.Rows(
+		"using",
+	))
+
+	h := dom.BindHandle()
+	// bindinfo.Lease is set to 0 at the beginning of this test, so GCBindRecord can run immediately.
+	require.NoError(t, h.GCBindRecord())
+	rows = tk.MustQuery("show global bindings").Rows()
+	require.Len(t, rows, 1)
+	require.Equal(t, "select * from `test` . `t` where `a` = ?", rows[0][0])
+	require.Equal(t, "using", rows[0][3])
+	tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . `t` where `a` = ?'").Check(testkit.Rows(
+		"using",
+	))
+
+	tk.MustExec("drop global binding for select * from t where a = 1")
+	tk.MustQuery("show global bindings").Check(testkit.Rows())
+	tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . `t` where `a` = ?'").Check(testkit.Rows(
+		"deleted",
+	))
+	require.NoError(t, h.GCBindRecord())
+	tk.MustQuery("show global bindings").Check(testkit.Rows())
+	tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . 
`t` where `a` = ?'").Check(testkit.Rows()) +} diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go deleted file mode 100644 index 6167b7bd600b6..0000000000000 --- a/bindinfo/bind_test.go +++ /dev/null @@ -1,1110 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bindinfo_test - -import ( - "context" - "crypto/tls" - "flag" - "fmt" - "os" - "testing" - - . "github.com/pingcap/check" - "github.com/pingcap/parser" - "github.com/pingcap/parser/auth" - "github.com/pingcap/parser/model" - "github.com/pingcap/parser/terror" - "github.com/pingcap/tidb/bindinfo" - "github.com/pingcap/tidb/config" - "github.com/pingcap/tidb/domain" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/meta/autoid" - "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/session/txninfo" - "github.com/pingcap/tidb/store/mockstore" - "github.com/pingcap/tidb/util" - "github.com/pingcap/tidb/util/logutil" - utilparser "github.com/pingcap/tidb/util/parser" - "github.com/pingcap/tidb/util/testkit" - "github.com/pingcap/tidb/util/testleak" - "github.com/tikv/client-go/v2/testutils" -) - -func TestT(t *testing.T) { - CustomVerboseFlag = true - logLevel := os.Getenv("log_level") - err := logutil.InitLogger(logutil.NewLogConfig(logLevel, logutil.DefaultLogFormat, "", logutil.EmptyFileLogConfig, false)) - if err != nil { - t.Fatal(err) - } - autoid.SetStep(5000) - TestingT(t) -} - -var _ = Suite(&testSuite{}) - -type testSuite 
struct { - cluster testutils.Cluster - store kv.Storage - domain *domain.Domain - *parser.Parser -} - -type mockSessionManager struct { - PS []*util.ProcessInfo -} - -func (msm *mockSessionManager) ShowTxnList() []*txninfo.TxnInfo { - panic("unimplemented!") -} - -func (msm *mockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { - ret := make(map[uint64]*util.ProcessInfo) - for _, item := range msm.PS { - ret[item.ID] = item - } - return ret -} - -func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { - for _, item := range msm.PS { - if item.ID == id { - return item, true - } - } - return &util.ProcessInfo{}, false -} - -func (msm *mockSessionManager) Kill(cid uint64, query bool) { -} - -func (msm *mockSessionManager) KillAllConnections() { -} - -func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) { -} - -func (msm *mockSessionManager) ServerID() uint64 { - return 1 -} - -var mockTikv = flag.Bool("mockTikv", true, "use mock tikv store in bind test") - -func (s *testSuite) SetUpSuite(c *C) { - testleak.BeforeTest() - s.Parser = parser.New() - flag.Lookup("mockTikv") - useMockTikv := *mockTikv - if useMockTikv { - store, err := mockstore.NewMockStore( - mockstore.WithClusterInspector(func(c testutils.Cluster) { - mockstore.BootstrapWithSingleStore(c) - s.cluster = c - }), - ) - c.Assert(err, IsNil) - s.store = store - session.SetSchemaLease(0) - session.DisableStats4Test() - } - bindinfo.Lease = 0 - d, err := session.BootstrapSession(s.store) - c.Assert(err, IsNil) - d.SetStatsUpdating(true) - s.domain = d -} - -func (s *testSuite) TearDownSuite(c *C) { - s.domain.Close() - s.store.Close() - testleak.AfterTest(c)() -} - -func (s *testSuite) TearDownTest(c *C) { - tk := testkit.NewTestKit(c, s.store) - tk.MustExec("use test") - r := tk.MustQuery("show tables") - for _, tb := range r.Rows() { - tableName := tb[0] - tk.MustExec(fmt.Sprintf("drop table %v", tableName)) - } -} - -func (s *testSuite) 
cleanBindingEnv(tk *testkit.TestKit) { - tk.MustExec("delete from mysql.bind_info where source != 'builtin'") - s.domain.BindHandle().Clear() -} - -func normalizeWithDefaultDB(c *C, sql, db string) (string, string) { - testParser := parser.New() - stmt, err := testParser.ParseOneStmt(sql, "", "") - c.Assert(err, IsNil) - normalized, digest := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, "test", "")) - return normalized, digest.String() -} - -var testSQLs = []struct { - createSQL string - overlaySQL string - querySQL string - originSQL string - bindSQL string - dropSQL string - memoryUsage float64 -}{ - { - createSQL: "binding for select * from t where i>100 using select * from t use index(index_t) where i>100", - overlaySQL: "binding for select * from t where i>99 using select * from t use index(index_t) where i>99", - querySQL: "select * from t where i > 30.0", - originSQL: "select * from `test` . `t` where `i` > ?", - bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) WHERE `i` > 99", - dropSQL: "binding for select * from t where i>100", - memoryUsage: float64(144), - }, - { - createSQL: "binding for select * from t union all select * from t using select * from t use index(index_t) union all select * from t use index()", - overlaySQL: "", - querySQL: "select * from t union all select * from t", - originSQL: "select * from `test` . `t` union all select * from `test` . `t`", - bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) UNION ALL SELECT * FROM `test`.`t` USE INDEX ()", - dropSQL: "binding for select * from t union all select * from t", - memoryUsage: float64(200), - }, - { - createSQL: "binding for (select * from t) union all (select * from t) using (select * from t use index(index_t)) union all (select * from t use index())", - overlaySQL: "", - querySQL: "(select * from t) union all (select * from t)", - originSQL: "( select * from `test` . `t` ) union all ( select * from `test` . 
`t` )", - bindSQL: "(SELECT * FROM `test`.`t` USE INDEX (`index_t`)) UNION ALL (SELECT * FROM `test`.`t` USE INDEX ())", - dropSQL: "binding for (select * from t) union all (select * from t)", - memoryUsage: float64(212), - }, - { - createSQL: "binding for select * from t intersect select * from t using select * from t use index(index_t) intersect select * from t use index()", - overlaySQL: "", - querySQL: "select * from t intersect select * from t", - originSQL: "select * from `test` . `t` intersect select * from `test` . `t`", - bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) INTERSECT SELECT * FROM `test`.`t` USE INDEX ()", - dropSQL: "binding for select * from t intersect select * from t", - memoryUsage: float64(200), - }, - { - createSQL: "binding for select * from t except select * from t using select * from t use index(index_t) except select * from t use index()", - overlaySQL: "", - querySQL: "select * from t except select * from t", - originSQL: "select * from `test` . `t` except select * from `test` . `t`", - bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) EXCEPT SELECT * FROM `test`.`t` USE INDEX ()", - dropSQL: "binding for select * from t except select * from t", - memoryUsage: float64(194), - }, - { - createSQL: "binding for select * from t using select /*+ use_index(t,index_t)*/ * from t", - overlaySQL: "", - querySQL: "select * from t ", - originSQL: "select * from `test` . `t`", - bindSQL: "SELECT /*+ use_index(`t` `index_t`)*/ * FROM `test`.`t`", - dropSQL: "binding for select * from t", - memoryUsage: float64(124), - }, - { - createSQL: "binding for delete from t where i = 1 using delete /*+ use_index(t,index_t) */ from t where i = 1", - overlaySQL: "", - querySQL: "delete from t where i = 2", - originSQL: "delete from `test` . 
`t` where `i` = ?", - bindSQL: "DELETE /*+ use_index(`t` `index_t`)*/ FROM `test`.`t` WHERE `i` = 1", - dropSQL: "binding for delete from t where i = 1", - memoryUsage: float64(148), - }, - { - createSQL: "binding for delete t, t1 from t inner join t1 on t.s = t1.s where t.i = 1 using delete /*+ use_index(t,index_t), hash_join(t,t1) */ t, t1 from t inner join t1 on t.s = t1.s where t.i = 1", - overlaySQL: "", - querySQL: "delete t, t1 from t inner join t1 on t.s = t1.s where t.i = 2", - originSQL: "delete `test` . `t` , `test` . `t1` from `test` . `t` join `test` . `t1` on `t` . `s` = `t1` . `s` where `t` . `i` = ?", - bindSQL: "DELETE /*+ use_index(`t` `index_t`) hash_join(`t`, `t1`)*/ `test`.`t`,`test`.`t1` FROM `test`.`t` JOIN `test`.`t1` ON `t`.`s` = `t1`.`s` WHERE `t`.`i` = 1", - dropSQL: "binding for delete t, t1 from t inner join t1 on t.s = t1.s where t.i = 1", - memoryUsage: float64(315), - }, - { - createSQL: "binding for update t set s = 'a' where i = 1 using update /*+ use_index(t,index_t) */ t set s = 'a' where i = 1", - overlaySQL: "", - querySQL: "update t set s='b' where i=2", - originSQL: "update `test` . `t` set `s` = ? where `i` = ?", - bindSQL: "UPDATE /*+ use_index(`t` `index_t`)*/ `test`.`t` SET `s`='a' WHERE `i` = 1", - dropSQL: "binding for update t set s = 'a' where i = 1", - memoryUsage: float64(162), - }, - { - createSQL: "binding for update t, t1 set t.s = 'a' where t.i = t1.i using update /*+ inl_join(t1) */ t, t1 set t.s = 'a' where t.i = t1.i", - overlaySQL: "", - querySQL: "update t , t1 set t.s='b' where t.i=t1.i", - originSQL: "update ( `test` . `t` ) join `test` . `t1` set `t` . `s` = ? where `t` . `i` = `t1` . 
`i`", - bindSQL: "UPDATE /*+ inl_join(`t1`)*/ (`test`.`t`) JOIN `test`.`t1` SET `t`.`s`='a' WHERE `t`.`i` = `t1`.`i`", - dropSQL: "binding for update t, t1 set t.s = 'a' where t.i = t1.i", - memoryUsage: float64(230), - }, - { - createSQL: "binding for insert into t1 select * from t where t.i = 1 using insert into t1 select /*+ use_index(t,index_t) */ * from t where t.i = 1", - overlaySQL: "", - querySQL: "insert into t1 select * from t where t.i = 2", - originSQL: "insert into `test` . `t1` select * from `test` . `t` where `t` . `i` = ?", - bindSQL: "INSERT INTO `test`.`t1` SELECT /*+ use_index(`t` `index_t`)*/ * FROM `test`.`t` WHERE `t`.`i` = 1", - dropSQL: "binding for insert into t1 select * from t where t.i = 1", - memoryUsage: float64(212), - }, - { - createSQL: "binding for replace into t1 select * from t where t.i = 1 using replace into t1 select /*+ use_index(t,index_t) */ * from t where t.i = 1", - overlaySQL: "", - querySQL: "replace into t1 select * from t where t.i = 2", - originSQL: "replace into `test` . `t1` select * from `test` . `t` where `t` . 
`i` = ?", - bindSQL: "REPLACE INTO `test`.`t1` SELECT /*+ use_index(`t` `index_t`)*/ * FROM `test`.`t` WHERE `t`.`i` = 1", - dropSQL: "binding for replace into t1 select * from t where t.i = 1", - memoryUsage: float64(214), - }, -} - -func (s *testSuite) TestExplain(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1") - tk.MustExec("drop table if exists t2") - tk.MustExec("create table t1(id int)") - tk.MustExec("create table t2(id int)") - - c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"), IsTrue) - c.Assert(tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) - - tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") - - c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) - - tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") - - // Add test for SetOprStmt - tk.MustExec("create index index_id on t1(id)") - c.Assert(tk.HasPlan("SELECT * from t1 union SELECT * from t1", "IndexReader"), IsFalse) - c.Assert(tk.HasPlan("SELECT * from t1 use index(index_id) union SELECT * from t1", "IndexReader"), IsTrue) - - tk.MustExec("create global binding for SELECT * from t1 union SELECT * from t1 using SELECT * from t1 use index(index_id) union SELECT * from t1") - - c.Assert(tk.HasPlan("SELECT * from t1 union SELECT * from t1", "IndexReader"), IsTrue) - - tk.MustExec("drop global binding for SELECT * from t1 union SELECT * from t1") -} - -// TestBindingSymbolList tests sql with "?, ?, ?, ?", fixes #13871 -func (s *testSuite) TestBindingSymbolList(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, INDEX ia (a), INDEX ib (b));") - 
tk.MustExec("insert into t value(1, 1);") - - // before binding - tk.MustQuery("select a, b from t where a = 3 limit 1, 100") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:ia") - c.Assert(tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ia(a)"), IsTrue) - - tk.MustExec(`create global binding for select a, b from t where a = 1 limit 0, 1 using select a, b from t use index (ib) where a = 1 limit 0, 1`) - - // after binding - tk.MustQuery("select a, b from t where a = 3 limit 1, 100") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:ib") - c.Assert(tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ib(b)"), IsTrue) - - // Normalize - sql, hash := parser.NormalizeDigest("select a, b from test . t where a = 1 limit 0, 1") - - bindData := s.domain.BindHandle().GetBindRecord(hash.String(), sql, "test") - c.Assert(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select `a` , `b` from `test` . `t` where `a` = ? limit ...") - bind := bindData.Bindings[0] - c.Check(bind.BindSQL, Equals, "SELECT `a`,`b` FROM `test`.`t` USE INDEX (`ib`) WHERE `a` = 1 LIMIT 0,1") - c.Check(bindData.Db, Equals, "test") - c.Check(bind.Status, Equals, "using") - c.Check(bind.Charset, NotNil) - c.Check(bind.Collation, NotNil) - c.Check(bind.CreateTime, NotNil) - c.Check(bind.UpdateTime, NotNil) -} - -func (s *testSuite) TestDMLSQLBind(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1, t2") - tk.MustExec("create table t1(a int, b int, c int, key idx_b(b), key idx_c(c))") - tk.MustExec("create table t2(a int, b int, c int, key idx_b(b), key idx_c(c))") - - tk.MustExec("delete from t1 where b = 1 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t1:idx_b") - c.Assert(tk.MustUseIndex("delete from t1 where b = 1 and c > 1", "idx_b(b)"), IsTrue) - tk.MustExec("create global binding for delete from t1 where b = 1 and 
c > 1 using delete /*+ use_index(t1,idx_c) */ from t1 where b = 1 and c > 1") - tk.MustExec("delete from t1 where b = 1 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t1:idx_c") - c.Assert(tk.MustUseIndex("delete from t1 where b = 1 and c > 1", "idx_c(c)"), IsTrue) - - c.Assert(tk.HasPlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "HashJoin"), IsTrue) - tk.MustExec("create global binding for delete t1, t2 from t1 inner join t2 on t1.b = t2.b using delete /*+ inl_join(t1) */ t1, t2 from t1 inner join t2 on t1.b = t2.b") - c.Assert(tk.HasPlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "IndexJoin"), IsTrue) - - tk.MustExec("update t1 set a = 1 where b = 1 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t1:idx_b") - c.Assert(tk.MustUseIndex("update t1 set a = 1 where b = 1 and c > 1", "idx_b(b)"), IsTrue) - tk.MustExec("create global binding for update t1 set a = 1 where b = 1 and c > 1 using update /*+ use_index(t1,idx_c) */ t1 set a = 1 where b = 1 and c > 1") - tk.MustExec("delete from t1 where b = 1 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t1:idx_c") - c.Assert(tk.MustUseIndex("update t1 set a = 1 where b = 1 and c > 1", "idx_c(c)"), IsTrue) - - c.Assert(tk.HasPlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "HashJoin"), IsTrue) - tk.MustExec("create global binding for update t1, t2 set t1.a = 1 where t1.b = t2.b using update /*+ inl_join(t1) */ t1, t2 set t1.a = 1 where t1.b = t2.b") - c.Assert(tk.HasPlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "IndexJoin"), IsTrue) - - tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t2:idx_b") - c.Assert(tk.MustUseIndex("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_b(b)"), IsTrue) - tk.MustExec("create global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1 using 
insert /*+ use_index(t2,idx_c) */ into t1 select * from t2 where t2.b = 1 and t2.c > 1") - tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t2:idx_b") - c.Assert(tk.MustUseIndex("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_b(b)"), IsTrue) - tk.MustExec("drop global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1") - tk.MustExec("create global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1 using insert into t1 select /*+ use_index(t2,idx_c) */ * from t2 where t2.b = 1 and t2.c > 1") - tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t2:idx_c") - c.Assert(tk.MustUseIndex("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_c(c)"), IsTrue) - - tk.MustExec("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t2:idx_b") - c.Assert(tk.MustUseIndex("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_b(b)"), IsTrue) - tk.MustExec("create global binding for replace into t1 select * from t2 where t2.b = 1 and t2.c > 1 using replace into t1 select /*+ use_index(t2,idx_c) */ * from t2 where t2.b = 1 and t2.c > 1") - tk.MustExec("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t2:idx_c") - c.Assert(tk.MustUseIndex("replace into t1 select * from t2 where t2.b = 2 and t2.c > 2", "idx_c(c)"), IsTrue) -} - -func (s *testSuite) TestBestPlanInBaselines(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, INDEX ia (a), INDEX ib (b));") - tk.MustExec("insert into t value(1, 1);") - - // before binding - 
tk.MustQuery("select a, b from t where a = 3 limit 1, 100") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:ia") - c.Assert(tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ia(a)"), IsTrue) - - tk.MustQuery("select a, b from t where b = 3 limit 1, 100") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:ib") - c.Assert(tk.MustUseIndex("select a, b from t where b = 3 limit 1, 100", "ib(b)"), IsTrue) - - tk.MustExec(`create global binding for select a, b from t where a = 1 limit 0, 1 using select /*+ use_index(@sel_1 test.t ia) */ a, b from t where a = 1 limit 0, 1`) - tk.MustExec(`create global binding for select a, b from t where b = 1 limit 0, 1 using select /*+ use_index(@sel_1 test.t ib) */ a, b from t where b = 1 limit 0, 1`) - - sql, hash := normalizeWithDefaultDB(c, "select a, b from t where a = 1 limit 0, 1", "test") - bindData := s.domain.BindHandle().GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select `a` , `b` from `test` . `t` where `a` = ? 
limit ...") - bind := bindData.Bindings[0] - c.Check(bind.BindSQL, Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` `ia`)*/ `a`,`b` FROM `test`.`t` WHERE `a` = 1 LIMIT 0,1") - c.Check(bindData.Db, Equals, "test") - c.Check(bind.Status, Equals, "using") - - tk.MustQuery("select a, b from t where a = 3 limit 1, 10") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:ia") - c.Assert(tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ia(a)"), IsTrue) - - tk.MustQuery("select a, b from t where b = 3 limit 1, 100") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:ib") - c.Assert(tk.MustUseIndex("select a, b from t where b = 3 limit 1, 100", "ib(b)"), IsTrue) -} - -func (s *testSuite) TestErrorBind(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustGetErrMsg("create global binding for select * from t using select * from t", "[schema:1146]Table 'test.t' doesn't exist") - tk.MustExec("drop table if exists t") - tk.MustExec("drop table if exists t1") - tk.MustExec("create table t(i int, s varchar(20))") - tk.MustExec("create table t1(i int, s varchar(20))") - tk.MustExec("create index index_t on t(i,s)") - - _, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100") - c.Assert(err, IsNil, Commentf("err %v", err)) - - sql, hash := parser.NormalizeDigest("select * from test . t where i > ?") - bindData := s.domain.BindHandle().GetBindRecord(hash.String(), sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . 
`t` where `i` > ?") - bind := bindData.Bindings[0] - c.Check(bind.BindSQL, Equals, "SELECT * FROM `test`.`t` USE INDEX (`index_t`) WHERE `i` > 100") - c.Check(bindData.Db, Equals, "test") - c.Check(bind.Status, Equals, "using") - c.Check(bind.Charset, NotNil) - c.Check(bind.Collation, NotNil) - c.Check(bind.CreateTime, NotNil) - c.Check(bind.UpdateTime, NotNil) - - tk.MustExec("drop index index_t on t") - _, err = tk.Exec("select * from t where i > 10") - c.Check(err, IsNil) - - s.domain.BindHandle().DropInvalidBindRecord() - - rs, err := tk.Exec("show global bindings") - c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(context.TODO(), chk) - c.Check(err, IsNil) - c.Check(chk.NumRows(), Equals, 0) -} - -func (s *testSuite) TestDMLEvolveBaselines(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = true - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, c int, index idx_b(b), index idx_c(c))") - tk.MustExec("insert into t values (1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5)") - tk.MustExec("analyze table t") - tk.MustExec("set @@tidb_evolve_plan_baselines=1") - - tk.MustExec("create global binding for delete from t where b = 1 and c > 1 using delete /*+ use_index(t,idx_c) */ from t where b = 1 and c > 1") - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - tk.MustExec("delete /*+ use_index(t,idx_b) */ from t where b = 2 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_c") - tk.MustExec("admin flush bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - tk.MustExec("admin evolve bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - - tk.MustExec("create global binding for update t 
set a = 1 where b = 1 and c > 1 using update /*+ use_index(t,idx_c) */ t set a = 1 where b = 1 and c > 1") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - tk.MustExec("update /*+ use_index(t,idx_b) */ t set a = 2 where b = 2 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_c") - tk.MustExec("admin flush bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - tk.MustExec("admin evolve bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - - tk.MustExec("create table t1 like t") - tk.MustExec("create global binding for insert into t1 select * from t where t.b = 1 and t.c > 1 using insert into t1 select /*+ use_index(t,idx_c) */ * from t where t.b = 1 and t.c > 1") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 3) - tk.MustExec("insert into t1 select /*+ use_index(t,idx_b) */ * from t where t.b = 2 and t.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_c") - tk.MustExec("admin flush bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 3) - tk.MustExec("admin evolve bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 3) - - tk.MustExec("create global binding for replace into t1 select * from t where t.b = 1 and t.c > 1 using replace into t1 select /*+ use_index(t,idx_c) */ * from t where t.b = 1 and t.c > 1") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 4) - tk.MustExec("replace into t1 select /*+ use_index(t,idx_b) */ * from t where t.b = 2 and t.c > 2") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_c") - tk.MustExec("admin flush bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 4) - tk.MustExec("admin evolve bindings") - rows = tk.MustQuery("show global 
bindings").Rows() - c.Assert(len(rows), Equals, 4) -} - -func (s *testSuite) TestAddEvolveTasks(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = true - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))") - tk.MustExec("insert into t values (1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5)") - tk.MustExec("analyze table t") - tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 and c = 0 using select * from t use index(idx_a) where a >= 1 and b >= 1 and c = 0") - tk.MustExec("set @@tidb_evolve_plan_baselines=1") - // It cannot choose table path although it has lowest cost. - tk.MustQuery("select * from t where a >= 4 and b >= 1 and c = 0") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_a") - tk.MustExec("admin flush bindings") - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - c.Assert(rows[0][1], Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0") - c.Assert(rows[0][3], Equals, "pending verify") - tk.MustExec("admin evolve bindings") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - c.Assert(rows[0][1], Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0") - status := rows[0][3].(string) - c.Assert(status == "using" || status == "rejected", IsTrue) -} - -func (s *testSuite) TestRuntimeHintsInEvolveTasks(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = true - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - 
tk.MustExec("drop table if exists t") - tk.MustExec("set @@tidb_evolve_plan_baselines=1") - tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))") - - tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 and c = 0 using select * from t use index(idx_a) where a >= 1 and b >= 1 and c = 0") - tk.MustQuery("select /*+ MAX_EXECUTION_TIME(5000) */ * from t where a >= 4 and b >= 1 and c = 0") - tk.MustExec("admin flush bindings") - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - c.Assert(rows[0][1], Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` `idx_c`), max_execution_time(5000)*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0") -} - -func (s *testSuite) TestDefaultSessionVars(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustQuery(`show variables like "%baselines%"`).Sort().Check(testkit.Rows( - "tidb_capture_plan_baselines OFF", - "tidb_evolve_plan_baselines OFF", - "tidb_use_plan_baselines ON")) - tk.MustQuery(`show global variables like "%baselines%"`).Sort().Check(testkit.Rows( - "tidb_capture_plan_baselines OFF", - "tidb_evolve_plan_baselines OFF", - "tidb_use_plan_baselines ON")) -} - -func (s *testSuite) TestCaptureBaselinesScope(c *C) { - tk1 := testkit.NewTestKit(c, s.store) - tk2 := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk1) - tk1.MustQuery(`show session variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( - "tidb_capture_plan_baselines OFF", - )) - tk1.MustQuery(`show global variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( - "tidb_capture_plan_baselines OFF", - )) - tk1.MustQuery(`select @@session.tidb_capture_plan_baselines`).Check(testkit.Rows( - "0", - )) - tk1.MustQuery(`select @@global.tidb_capture_plan_baselines`).Check(testkit.Rows( - "0", - )) - - tk1.MustExec("set @@session.tidb_capture_plan_baselines = on") - defer func() { - tk1.MustExec(" set 
@@session.tidb_capture_plan_baselines = off") - }() - tk1.MustQuery(`show session variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( - "tidb_capture_plan_baselines ON", - )) - tk1.MustQuery(`show global variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( - "tidb_capture_plan_baselines OFF", - )) - tk1.MustQuery(`select @@session.tidb_capture_plan_baselines`).Check(testkit.Rows( - "1", - )) - tk1.MustQuery(`select @@global.tidb_capture_plan_baselines`).Check(testkit.Rows( - "0", - )) - tk2.MustQuery(`show session variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( - "tidb_capture_plan_baselines ON", - )) - tk2.MustQuery(`show global variables like "tidb_capture_plan_baselines"`).Check(testkit.Rows( - "tidb_capture_plan_baselines OFF", - )) - tk2.MustQuery(`select @@session.tidb_capture_plan_baselines`).Check(testkit.Rows( - "1", - )) - tk2.MustQuery(`select @@global.tidb_capture_plan_baselines`).Check(testkit.Rows( - "0", - )) -} - -func (s *testSuite) TestStmtHints(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, index idx(a))") - tk.MustExec("create global binding for select * from t using select /*+ MAX_EXECUTION_TIME(100), MEMORY_QUOTA(1 GB) */ * from t use index(idx)") - tk.MustQuery("select * from t") - c.Assert(tk.Se.GetSessionVars().StmtCtx.MemQuotaQuery, Equals, int64(1073741824)) - c.Assert(tk.Se.GetSessionVars().StmtCtx.MaxExecutionTime, Equals, uint64(100)) - tk.MustQuery("select a, b from t") - c.Assert(tk.Se.GetSessionVars().StmtCtx.MemQuotaQuery, Equals, int64(0)) - c.Assert(tk.Se.GetSessionVars().StmtCtx.MaxExecutionTime, Equals, uint64(0)) -} - -func (s *testSuite) TestPrivileges(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, index idx(a))") - 
tk.MustExec("create global binding for select * from t using select * from t use index(idx)") - c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - tk.MustExec("create user test@'%'") - c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "test", Hostname: "%"}, nil, nil), IsTrue) - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 0) -} - -func (s *testSuite) TestHintsSetEvolveTask(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = true - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, index idx_a(a))") - tk.MustExec("create global binding for select * from t where a > 10 using select * from t ignore index(idx_a) where a > 10") - tk.MustExec("set @@tidb_evolve_plan_baselines=1") - tk.MustQuery("select * from t use index(idx_a) where a > 0") - bindHandle := s.domain.BindHandle() - bindHandle.SaveEvolveTasksToStore() - // Verify the added Binding for evolution contains valid ID and Hint, otherwise, panic may happen. - sql, hash := normalizeWithDefaultDB(c, "select * from t where a > ?", "test") - bindData := bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . 
`t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 2) - bind := bindData.Bindings[1] - c.Assert(bind.Status, Equals, bindinfo.PendingVerify) - c.Assert(bind.ID, Not(Equals), "") - c.Assert(bind.Hint, NotNil) -} - -func (s *testSuite) TestHintsSetID(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, index idx_a(a))") - tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(test.t, idx_a) */ * from t where a > 10") - bindHandle := s.domain.BindHandle() - // Verify the added Binding contains ID with restored query block. - sql, hash := normalizeWithDefaultDB(c, "select * from t where a > ?", "test") - bindData := bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . `t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 1) - bind := bindData.Bindings[0] - c.Assert(bind.ID, Equals, "use_index(@`sel_1` `test`.`t` `idx_a`)") - - s.cleanBindingEnv(tk) - tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(t, idx_a) */ * from t where a > 10") - bindData = bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . `t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 1) - bind = bindData.Bindings[0] - c.Assert(bind.ID, Equals, "use_index(@`sel_1` `test`.`t` `idx_a`)") - - s.cleanBindingEnv(tk) - tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(@sel_1 t, idx_a) */ * from t where a > 10") - bindData = bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . 
`t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 1) - bind = bindData.Bindings[0] - c.Assert(bind.ID, Equals, "use_index(@`sel_1` `test`.`t` `idx_a`)") - - s.cleanBindingEnv(tk) - tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(@qb1 t, idx_a) qb_name(qb1) */ * from t where a > 10") - bindData = bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . `t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 1) - bind = bindData.Bindings[0] - c.Assert(bind.ID, Equals, "use_index(@`sel_1` `test`.`t` `idx_a`)") - - s.cleanBindingEnv(tk) - tk.MustExec("create global binding for select * from t where a > 10 using select /*+ use_index(T, IDX_A) */ * from t where a > 10") - bindData = bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . `t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 1) - bind = bindData.Bindings[0] - c.Assert(bind.ID, Equals, "use_index(@`sel_1` `test`.`t` `idx_a`)") - - s.cleanBindingEnv(tk) - err := tk.ExecToErr("create global binding for select * from t using select /*+ non_exist_hint() */ * from t") - c.Assert(terror.ErrorEqual(err, parser.ErrWarnOptimizerHintParseError), IsTrue) - tk.MustExec("create global binding for select * from t where a > 10 using select * from t where a > 10") - bindData = bindHandle.GetBindRecord(hash, sql, "test") - c.Check(bindData, NotNil) - c.Check(bindData.OriginalSQL, Equals, "select * from `test` . 
`t` where `a` > ?") - c.Assert(len(bindData.Bindings), Equals, 1) - bind = bindData.Bindings[0] - c.Assert(bind.ID, Equals, "") -} - -func (s *testSuite) TestNotEvolvePlanForReadStorageHint(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = true - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, index idx_a(a), index idx_b(b))") - tk.MustExec("insert into t values (1,1), (2,2), (3,3), (4,4), (5,5), (6,6), (7,7), (8,8), (9,9), (10,10)") - tk.MustExec("analyze table t") - // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) - is := dom.InfoSchema() - db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) - for _, tblInfo := range db.Tables { - if tblInfo.Name.L == "t" { - tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ - Count: 1, - Available: true, - } - } - } - - // Make sure the best plan of the SQL is use TiKV index. - tk.MustExec("set @@session.tidb_executor_concurrency = 4;") - rows := tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() - c.Assert(fmt.Sprintf("%v", rows[len(rows)-1][2]), Equals, "cop[tikv]") - - tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select /*+ read_from_storage(tiflash[t]) */ * from t where a >= 1 and b >= 1") - tk.MustExec("set @@tidb_evolve_plan_baselines=1") - - // Even if index of TiKV has lower cost, it chooses TiFlash. - rows = tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() - c.Assert(fmt.Sprintf("%v", rows[len(rows)-1][2]), Equals, "cop[tiflash]") - - tk.MustExec("admin flush bindings") - rows = tk.MustQuery("show global bindings").Rows() - // None evolve task, because of the origin binding is a read_from_storage binding. 
- c.Assert(len(rows), Equals, 1) - c.Assert(rows[0][1], Equals, "SELECT /*+ read_from_storage(tiflash[`t`])*/ * FROM `test`.`t` WHERE `a` >= 1 AND `b` >= 1") - c.Assert(rows[0][3], Equals, "using") -} - -func (s *testSuite) TestBindingWithIsolationRead(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, index idx_a(a), index idx_b(b))") - tk.MustExec("insert into t values (1,1), (2,2), (3,3), (4,4), (5,5), (6,6), (7,7), (8,8), (9,9), (10,10)") - tk.MustExec("analyze table t") - // Create virtual tiflash replica info. - dom := domain.GetDomain(tk.Se) - is := dom.InfoSchema() - db, exists := is.SchemaByName(model.NewCIStr("test")) - c.Assert(exists, IsTrue) - for _, tblInfo := range db.Tables { - if tblInfo.Name.L == "t" { - tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ - Count: 1, - Available: true, - } - } - } - tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select * from t use index(idx_a) where a >= 1 and b >= 1") - tk.MustExec("set @@tidb_use_plan_baselines = 1") - rows := tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() - c.Assert(rows[len(rows)-1][2], Equals, "cop[tikv]") - // Even if we build a binding use index for SQL, but after we set the isolation read for TiFlash, it choose TiFlash instead of index of TiKV. 
- tk.MustExec("set @@tidb_isolation_read_engines = \"tiflash\"") - rows = tk.MustQuery("explain select * from t where a >= 11 and b >= 11").Rows() - c.Assert(rows[len(rows)-1][2], Equals, "cop[tiflash]") -} - -func (s *testSuite) TestReCreateBindAfterEvolvePlan(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = true - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))") - tk.MustExec("insert into t values (1,1,1), (2,2,2), (3,3,3), (4,4,4), (5,5,5)") - tk.MustExec("analyze table t") - tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select * from t use index(idx_a) where a >= 1 and b >= 1") - tk.MustExec("set @@tidb_evolve_plan_baselines=1") - - // It cannot choose table path although it has lowest cost. 
- tk.MustQuery("select * from t where a >= 0 and b >= 0") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_a") - - tk.MustExec("admin flush bindings") - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 2) - c.Assert(rows[0][1], Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t` WHERE `a` >= 0 AND `b` >= 0") - c.Assert(rows[0][3], Equals, "pending verify") - - tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 using select * from t use index(idx_b) where a >= 1 and b >= 1") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - tk.MustQuery("select * from t where a >= 4 and b >= 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_b") -} - -func (s *testSuite) TestInvisibleIndex(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, unique idx_a(a), index idx_b(b) invisible)") - tk.MustGetErrMsg( - "create global binding for select * from t using select * from t use index(idx_b) ", - "[planner:1176]Key 'idx_b' doesn't exist in table 't'") - - // Create bind using index - tk.MustExec("create global binding for select * from t using select * from t use index(idx_a) ") - - tk.MustQuery("select * from t") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_a") - c.Assert(tk.MustUseIndex("select * from t", "idx_a(a)"), IsTrue) - - tk.MustExec(`prepare stmt1 from 'select * from t'`) - tk.MustExec("execute stmt1") - c.Assert(len(tk.Se.GetSessionVars().StmtCtx.IndexNames), Equals, 1) - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_a") - - // And then make this index invisible - tk.MustExec("alter table t alter index idx_a invisible") - tk.MustQuery("select * from t") - c.Assert(len(tk.Se.GetSessionVars().StmtCtx.IndexNames), Equals, 0) - - 
tk.MustExec("execute stmt1") - c.Assert(len(tk.Se.GetSessionVars().StmtCtx.IndexNames), Equals, 0) - - tk.MustExec("drop binding for select * from t") -} - -func (s *testSuite) TestSPMHitInfo(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1") - tk.MustExec("drop table if exists t2") - tk.MustExec("create table t1(id int)") - tk.MustExec("create table t2(id int)") - - c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"), IsTrue) - c.Assert(tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) - - tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") - tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0")) - tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") - - c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) - tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") - tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1")) - tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") -} - -func (s *testSuite) TestReCreateBind(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, index idx(a))") - - tk.MustQuery("select * from mysql.bind_info where source != 'builtin'").Check(testkit.Rows()) - tk.MustQuery("show global bindings").Check(testkit.Rows()) - - tk.MustExec("create global binding for select * from t using select * from t") - tk.MustQuery("select original_sql, status from mysql.bind_info where source != 'builtin';").Check(testkit.Rows( - "select * from `test` . 
`t` using", - )) - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - c.Assert(rows[0][0], Equals, "select * from `test` . `t`") - c.Assert(rows[0][3], Equals, "using") - - tk.MustExec("create global binding for select * from t using select * from t") - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - c.Assert(rows[0][0], Equals, "select * from `test` . `t`") - c.Assert(rows[0][3], Equals, "using") - - rows = tk.MustQuery("select original_sql, status from mysql.bind_info where source != 'builtin';").Rows() - c.Assert(len(rows), Equals, 2) - c.Assert(rows[0][1], Equals, "deleted") - c.Assert(rows[1][1], Equals, "using") -} - -func (s *testSuite) TestExplainShowBindSQL(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, key(a))") - - tk.MustExec("create global binding for select * from t using select * from t use index(a)") - tk.MustQuery("select original_sql, bind_sql from mysql.bind_info where default_db != 'mysql'").Check(testkit.Rows( - "select * from `test` . `t` SELECT * FROM `test`.`t` USE INDEX (`a`)", - )) - - tk.MustExec("explain format = 'verbose' select * from t") - tk.MustQuery("show warnings").Check(testkit.Rows("Note 1105 Using the bindSQL: SELECT * FROM `test`.`t` USE INDEX (`a`)")) - // explain analyze do not support verbose yet. 
-} - -func (s *testSuite) TestDMLIndexHintBind(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("create table t(a int, b int, c int, key idx_b(b), key idx_c(c))") - - tk.MustExec("delete from t where b = 1 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_b") - c.Assert(tk.MustUseIndex("delete from t where b = 1 and c > 1", "idx_b(b)"), IsTrue) - tk.MustExec("create global binding for delete from t where b = 1 and c > 1 using delete from t use index(idx_c) where b = 1 and c > 1") - tk.MustExec("delete from t where b = 1 and c > 1") - c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_c") - c.Assert(tk.MustUseIndex("delete from t where b = 1 and c > 1", "idx_c(c)"), IsTrue) -} - -func (s *testSuite) TestForbidEvolvePlanBaseLinesBeforeGA(c *C) { - originalVal := config.CheckTableBeforeDrop - config.CheckTableBeforeDrop = false - defer func() { - config.CheckTableBeforeDrop = originalVal - }() - - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - err := tk.ExecToErr("set @@tidb_evolve_plan_baselines=0") - c.Assert(err, Equals, nil) - err = tk.ExecToErr("set @@TiDB_Evolve_pLan_baselines=1") - c.Assert(err, ErrorMatches, "Cannot enable baseline evolution feature, it is not generally available now") - err = tk.ExecToErr("set @@TiDB_Evolve_pLan_baselines=oN") - c.Assert(err, ErrorMatches, "Cannot enable baseline evolution feature, it is not generally available now") - err = tk.ExecToErr("admin evolve bindings") - c.Assert(err, ErrorMatches, "Cannot enable baseline evolution feature, it is not generally available now") -} - -func (s *testSuite) TestExplainTableStmts(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(id int, value decimal(5,2))") - tk.MustExec("table t") - tk.MustExec("explain table t") - tk.MustExec("desc table t") 
-} - -func (s *testSuite) TestSPMWithoutUseDatabase(c *C) { - tk := testkit.NewTestKit(c, s.store) - tk1 := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - s.cleanBindingEnv(tk1) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, key(a))") - tk.MustExec("create global binding for select * from t using select * from t force index(a)") - - err := tk1.ExecToErr("select * from t") - c.Assert(err, ErrorMatches, "*No database selected") - tk1.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0")) - c.Assert(tk1.MustUseIndex("select * from test.t", "a"), IsTrue) - tk1.MustExec("select * from test.t") - tk1.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1")) -} - -func (s *testSuite) TestBindingWithoutCharset(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t (a varchar(10) CHARACTER SET utf8)") - tk.MustExec("create global binding for select * from t where a = 'aa' using select * from t where a = 'aa'") - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - c.Assert(rows[0][0], Equals, "select * from `test` . `t` where `a` = ?") - c.Assert(rows[0][1], Equals, "SELECT * FROM `test`.`t` WHERE `a` = 'aa'") -} - -func (s *testSuite) TestGCBindRecord(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int, b int, key(a))") - - tk.MustExec("create global binding for select * from t where a = 1 using select * from t use index(a) where a = 1") - rows := tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - c.Assert(rows[0][0], Equals, "select * from `test` . 
`t` where `a` = ?") - c.Assert(rows[0][3], Equals, "using") - tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . `t` where `a` = ?'").Check(testkit.Rows( - "using", - )) - - h := s.domain.BindHandle() - // bindinfo.Lease is set to 0 for test env in SetUpSuite. - c.Assert(h.GCBindRecord(), IsNil) - rows = tk.MustQuery("show global bindings").Rows() - c.Assert(len(rows), Equals, 1) - c.Assert(rows[0][0], Equals, "select * from `test` . `t` where `a` = ?") - c.Assert(rows[0][3], Equals, "using") - tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . `t` where `a` = ?'").Check(testkit.Rows( - "using", - )) - - tk.MustExec("drop global binding for select * from t where a = 1") - tk.MustQuery("show global bindings").Check(testkit.Rows()) - tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . `t` where `a` = ?'").Check(testkit.Rows( - "deleted", - )) - c.Assert(h.GCBindRecord(), IsNil) - tk.MustQuery("show global bindings").Check(testkit.Rows()) - tk.MustQuery("select status from mysql.bind_info where original_sql = 'select * from `test` . 
`t` where `a` = ?'").Check(testkit.Rows()) -} diff --git a/bindinfo/handle_serial_test.go b/bindinfo/handle_serial_test.go index af00190e9dd4a..65fda9f804771 100644 --- a/bindinfo/handle_serial_test.go +++ b/bindinfo/handle_serial_test.go @@ -224,6 +224,125 @@ func TestEvolveInvalidBindings(t *testing.T) { require.True(t, status == "using" || status == "rejected") } +var testSQLs = []struct { + createSQL string + overlaySQL string + querySQL string + originSQL string + bindSQL string + dropSQL string + memoryUsage float64 +}{ + { + createSQL: "binding for select * from t where i>100 using select * from t use index(index_t) where i>100", + overlaySQL: "binding for select * from t where i>99 using select * from t use index(index_t) where i>99", + querySQL: "select * from t where i > 30.0", + originSQL: "select * from `test` . `t` where `i` > ?", + bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) WHERE `i` > 99", + dropSQL: "binding for select * from t where i>100", + memoryUsage: float64(144), + }, + { + createSQL: "binding for select * from t union all select * from t using select * from t use index(index_t) union all select * from t use index()", + overlaySQL: "", + querySQL: "select * from t union all select * from t", + originSQL: "select * from `test` . `t` union all select * from `test` . `t`", + bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) UNION ALL SELECT * FROM `test`.`t` USE INDEX ()", + dropSQL: "binding for select * from t union all select * from t", + memoryUsage: float64(200), + }, + { + createSQL: "binding for (select * from t) union all (select * from t) using (select * from t use index(index_t)) union all (select * from t use index())", + overlaySQL: "", + querySQL: "(select * from t) union all (select * from t)", + originSQL: "( select * from `test` . `t` ) union all ( select * from `test` . 
`t` )", + bindSQL: "(SELECT * FROM `test`.`t` USE INDEX (`index_t`)) UNION ALL (SELECT * FROM `test`.`t` USE INDEX ())", + dropSQL: "binding for (select * from t) union all (select * from t)", + memoryUsage: float64(212), + }, + { + createSQL: "binding for select * from t intersect select * from t using select * from t use index(index_t) intersect select * from t use index()", + overlaySQL: "", + querySQL: "select * from t intersect select * from t", + originSQL: "select * from `test` . `t` intersect select * from `test` . `t`", + bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) INTERSECT SELECT * FROM `test`.`t` USE INDEX ()", + dropSQL: "binding for select * from t intersect select * from t", + memoryUsage: float64(200), + }, + { + createSQL: "binding for select * from t except select * from t using select * from t use index(index_t) except select * from t use index()", + overlaySQL: "", + querySQL: "select * from t except select * from t", + originSQL: "select * from `test` . `t` except select * from `test` . `t`", + bindSQL: "SELECT * FROM `test`.`t` USE INDEX (`index_t`) EXCEPT SELECT * FROM `test`.`t` USE INDEX ()", + dropSQL: "binding for select * from t except select * from t", + memoryUsage: float64(194), + }, + { + createSQL: "binding for select * from t using select /*+ use_index(t,index_t)*/ * from t", + overlaySQL: "", + querySQL: "select * from t ", + originSQL: "select * from `test` . `t`", + bindSQL: "SELECT /*+ use_index(`t` `index_t`)*/ * FROM `test`.`t`", + dropSQL: "binding for select * from t", + memoryUsage: float64(124), + }, + { + createSQL: "binding for delete from t where i = 1 using delete /*+ use_index(t,index_t) */ from t where i = 1", + overlaySQL: "", + querySQL: "delete from t where i = 2", + originSQL: "delete from `test` . 
`t` where `i` = ?", + bindSQL: "DELETE /*+ use_index(`t` `index_t`)*/ FROM `test`.`t` WHERE `i` = 1", + dropSQL: "binding for delete from t where i = 1", + memoryUsage: float64(148), + }, + { + createSQL: "binding for delete t, t1 from t inner join t1 on t.s = t1.s where t.i = 1 using delete /*+ use_index(t,index_t), hash_join(t,t1) */ t, t1 from t inner join t1 on t.s = t1.s where t.i = 1", + overlaySQL: "", + querySQL: "delete t, t1 from t inner join t1 on t.s = t1.s where t.i = 2", + originSQL: "delete `test` . `t` , `test` . `t1` from `test` . `t` join `test` . `t1` on `t` . `s` = `t1` . `s` where `t` . `i` = ?", + bindSQL: "DELETE /*+ use_index(`t` `index_t`) hash_join(`t`, `t1`)*/ `test`.`t`,`test`.`t1` FROM `test`.`t` JOIN `test`.`t1` ON `t`.`s` = `t1`.`s` WHERE `t`.`i` = 1", + dropSQL: "binding for delete t, t1 from t inner join t1 on t.s = t1.s where t.i = 1", + memoryUsage: float64(315), + }, + { + createSQL: "binding for update t set s = 'a' where i = 1 using update /*+ use_index(t,index_t) */ t set s = 'a' where i = 1", + overlaySQL: "", + querySQL: "update t set s='b' where i=2", + originSQL: "update `test` . `t` set `s` = ? where `i` = ?", + bindSQL: "UPDATE /*+ use_index(`t` `index_t`)*/ `test`.`t` SET `s`='a' WHERE `i` = 1", + dropSQL: "binding for update t set s = 'a' where i = 1", + memoryUsage: float64(162), + }, + { + createSQL: "binding for update t, t1 set t.s = 'a' where t.i = t1.i using update /*+ inl_join(t1) */ t, t1 set t.s = 'a' where t.i = t1.i", + overlaySQL: "", + querySQL: "update t , t1 set t.s='b' where t.i=t1.i", + originSQL: "update ( `test` . `t` ) join `test` . `t1` set `t` . `s` = ? where `t` . `i` = `t1` . 
`i`", + bindSQL: "UPDATE /*+ inl_join(`t1`)*/ (`test`.`t`) JOIN `test`.`t1` SET `t`.`s`='a' WHERE `t`.`i` = `t1`.`i`", + dropSQL: "binding for update t, t1 set t.s = 'a' where t.i = t1.i", + memoryUsage: float64(230), + }, + { + createSQL: "binding for insert into t1 select * from t where t.i = 1 using insert into t1 select /*+ use_index(t,index_t) */ * from t where t.i = 1", + overlaySQL: "", + querySQL: "insert into t1 select * from t where t.i = 2", + originSQL: "insert into `test` . `t1` select * from `test` . `t` where `t` . `i` = ?", + bindSQL: "INSERT INTO `test`.`t1` SELECT /*+ use_index(`t` `index_t`)*/ * FROM `test`.`t` WHERE `t`.`i` = 1", + dropSQL: "binding for insert into t1 select * from t where t.i = 1", + memoryUsage: float64(212), + }, + { + createSQL: "binding for replace into t1 select * from t where t.i = 1 using replace into t1 select /*+ use_index(t,index_t) */ * from t where t.i = 1", + overlaySQL: "", + querySQL: "replace into t1 select * from t where t.i = 2", + originSQL: "replace into `test` . `t1` select * from `test` . `t` where `t` . 
`i` = ?", + bindSQL: "REPLACE INTO `test`.`t1` SELECT /*+ use_index(`t` `index_t`)*/ * FROM `test`.`t` WHERE `t`.`i` = 1", + dropSQL: "binding for replace into t1 select * from t where t.i = 1", + memoryUsage: float64(214), + }, +} + func TestGlobalBinding(t *testing.T) { store, dom, clean := testkit.CreateMockStoreAndDomain(t) defer clean() diff --git a/bindinfo/session_handle_serial_test.go b/bindinfo/session_handle_serial_test.go index c3b41645ce9b0..ac890e7a3569d 100644 --- a/bindinfo/session_handle_serial_test.go +++ b/bindinfo/session_handle_serial_test.go @@ -16,6 +16,7 @@ package bindinfo_test import ( "context" + "crypto/tls" "strconv" "testing" "time" @@ -25,6 +26,7 @@ import ( "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/metrics" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/session/txninfo" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/stmtsummary" @@ -364,6 +366,44 @@ func TestDefaultDB(t *testing.T) { tk.MustQuery("show session bindings").Check(testkit.Rows()) } +type mockSessionManager struct { + PS []*util.ProcessInfo +} + +func (msm *mockSessionManager) ShowTxnList() []*txninfo.TxnInfo { + panic("unimplemented!") +} + +func (msm *mockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { + ret := make(map[uint64]*util.ProcessInfo) + for _, item := range msm.PS { + ret[item.ID] = item + } + return ret +} + +func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { + for _, item := range msm.PS { + if item.ID == id { + return item, true + } + } + return &util.ProcessInfo{}, false +} + +func (msm *mockSessionManager) Kill(cid uint64, query bool) { +} + +func (msm *mockSessionManager) KillAllConnections() { +} + +func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) { +} + +func (msm *mockSessionManager) ServerID() uint64 { + return 1 +} + func TestIssue19836(t *testing.T) { store, clean := 
testkit.CreateMockStore(t) defer clean() diff --git a/br/pkg/kv/kv.go b/br/pkg/kv/kv.go index 229255b0d5867..cef7ff486c49f 100644 --- a/br/pkg/kv/kv.go +++ b/br/pkg/kv/kv.go @@ -16,6 +16,7 @@ package kv import ( "bytes" + "context" "fmt" "math" "sort" @@ -350,11 +351,11 @@ func (kvcodec *tableKVEncoder) AddRecord( incrementalBits-- } alloc := kvcodec.tbl.Allocators(kvcodec.se).Get(autoid.AutoRandomType) - _ = alloc.Rebase(value.GetInt64()&((1< 0.1 { - return false, fmt.Sprintf("The diff(%.2f) between actual(%.2f) and expect(%.2f) is too huge.", diff, actual, expect) - } - return true, "" -} - -func (s *testLoggingSuite) TestRater(c *C) { +func TestRater(t *testing.T) { + t.Parallel() m := prometheus.NewCounter(prometheus.CounterOpts{ Namespace: "testing", Name: "rater", @@ -87,19 +57,21 @@ func (s *testLoggingSuite) TestRater(c *C) { rater := logutil.TraceRateOver(m) timePass := time.Now() rater.Inc() - c.Assert(rater.RateAt(timePass.Add(100*time.Millisecond)), isAbout{}, 10.0) + require.InEpsilon(t, 10.0, rater.RateAt(timePass.Add(100*time.Millisecond)), 0.1) rater.Inc() - c.Assert(rater.RateAt(timePass.Add(150*time.Millisecond)), isAbout{}, 13.0) + require.InEpsilon(t, 13.0, rater.RateAt(timePass.Add(150*time.Millisecond)), 0.1) rater.Add(18) - c.Assert(rater.RateAt(timePass.Add(200*time.Millisecond)), isAbout{}, 100.0) + require.InEpsilon(t, 100.0, rater.RateAt(timePass.Add(200*time.Millisecond)), 0.1) } -func (s *testLoggingSuite) TestFile(c *C) { - assertTrimEqual(c, logutil.File(newFile(1)), +func TestFile(t *testing.T) { + t.Parallel() + assertTrimEqual(t, logutil.File(newFile(1)), `{"file": {"name": "1", "CF": "write", "sha256": "31", "startKey": "31", "endKey": "32", "startVersion": 1, "endVersion": 2, "totalKvs": 1, "totalBytes": 1, "CRC64Xor": 1}}`) } -func (s *testLoggingSuite) TestFiles(c *C) { +func TestFiles(t *testing.T) { + t.Parallel() cases := []struct { count int expect string @@ -119,18 +91,20 @@ func (s *testLoggingSuite) TestFiles(c *C) { for 
j := 0; j < cs.count; j++ { ranges[j] = newFile(j) } - assertTrimEqual(c, logutil.Files(ranges), cs.expect) + assertTrimEqual(t, logutil.Files(ranges), cs.expect) } } -func (s *testLoggingSuite) TestKey(c *C) { +func TestKey(t *testing.T) { + t.Parallel() encoder := zapcore.NewConsoleEncoder(zapcore.EncoderConfig{}) out, err := encoder.EncodeEntry(zapcore.Entry{}, []zap.Field{logutil.Key("test", []byte{0, 1, 2, 3})}) - c.Assert(err, IsNil) - c.Assert(strings.Trim(out.String(), "\n"), Equals, `{"test": "00010203"}`) + require.NoError(t, err) + require.JSONEq(t, `{"test": "00010203"}`, strings.Trim(out.String(), "\n")) } -func (s *testLoggingSuite) TestKeys(c *C) { +func TestKeys(t *testing.T) { + t.Parallel() cases := []struct { count int expect string @@ -150,11 +124,12 @@ func (s *testLoggingSuite) TestKeys(c *C) { for j := 0; j < cs.count; j++ { keys[j] = []byte(fmt.Sprintf("%04d", j)) } - assertTrimEqual(c, logutil.Keys(keys), cs.expect) + assertTrimEqual(t, logutil.Keys(keys), cs.expect) } } -func (s *testLoggingSuite) TestRewriteRule(c *C) { +func TestRewriteRule(t *testing.T) { + t.Parallel() rule := &import_sstpb.RewriteRule{ OldKeyPrefix: []byte("old"), NewKeyPrefix: []byte("new"), @@ -163,11 +138,12 @@ func (s *testLoggingSuite) TestRewriteRule(c *C) { encoder := zapcore.NewConsoleEncoder(zapcore.EncoderConfig{}) out, err := encoder.EncodeEntry(zapcore.Entry{}, []zap.Field{logutil.RewriteRule(rule)}) - c.Assert(err, IsNil) - c.Assert(strings.Trim(out.String(), "\n"), Equals, `{"rewriteRule": {"oldKeyPrefix": "6f6c64", "newKeyPrefix": "6e6577", "newTimestamp": 5592405}}`) + require.NoError(t, err) + require.JSONEq(t, `{"rewriteRule": {"oldKeyPrefix": "6f6c64", "newKeyPrefix": "6e6577", "newTimestamp": 5592405}}`, strings.Trim(out.String(), "\n")) } -func (s *testLoggingSuite) TestRegion(c *C) { +func TestRegion(t *testing.T) { + t.Parallel() region := &metapb.Region{ Id: 1, StartKey: []byte{0x00, 0x01}, @@ -176,17 +152,19 @@ func (s *testLoggingSuite) 
TestRegion(c *C) { Peers: []*metapb.Peer{{Id: 2, StoreId: 3}, {Id: 4, StoreId: 5}}, } - assertTrimEqual(c, logutil.Region(region), + assertTrimEqual(t, logutil.Region(region), `{"region": {"ID": 1, "startKey": "0001", "endKey": "0002", "epoch": "conf_ver:1 version:1 ", "peers": "id:2 store_id:3 ,id:4 store_id:5 "}}`) } -func (s *testLoggingSuite) TestLeader(c *C) { +func TestLeader(t *testing.T) { + t.Parallel() leader := &metapb.Peer{Id: 2, StoreId: 3} - assertTrimEqual(c, logutil.Leader(leader), `{"leader": "id:2 store_id:3 "}`) + assertTrimEqual(t, logutil.Leader(leader), `{"leader": "id:2 store_id:3 "}`) } -func (s *testLoggingSuite) TestSSTMeta(c *C) { +func TestSSTMeta(t *testing.T) { + t.Parallel() meta := &import_sstpb.SSTMeta{ Uuid: []byte("mock uuid"), Range: &import_sstpb.Range{ @@ -200,39 +178,19 @@ func (s *testLoggingSuite) TestSSTMeta(c *C) { RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, } - assertTrimEqual(c, logutil.SSTMeta(meta), + assertTrimEqual(t, logutil.SSTMeta(meta), `{"sstMeta": {"CF": "default", "endKeyExclusive": false, "CRC32": 5592405, "length": 1, "regionID": 1, "regionEpoch": "conf_ver:1 version:1 ", "startKey": "0001", "endKey": "0002", "UUID": "invalid UUID 6d6f636b2075756964"}}`) } -func (s *testLoggingSuite) TestShortError(c *C) { +func TestShortError(t *testing.T) { + t.Parallel() err := errors.Annotate(berrors.ErrInvalidArgument, "test") - assertTrimEqual(c, logutil.ShortError(err), `{"error": "test: [BR:Common:ErrInvalidArgument]invalid argument"}`) -} - -type FieldEquals struct{} - -func (f FieldEquals) Info() *CheckerInfo { - return &CheckerInfo{ - Name: "FieldEquals", - Params: []string{ - "expected", - "actual", - }, - } -} - -func (f FieldEquals) Check(params []interface{}, names []string) (result bool, err string) { - expected := params[0].(zap.Field) - actual := params[1].(zap.Field) - - if !expected.Equals(actual) { - return false, "Field not match." 
- } - return true, "" + assertTrimEqual(t, logutil.ShortError(err), `{"error": "test: [BR:Common:ErrInvalidArgument]invalid argument"}`) } -func (s *testLoggingSuite) TestContextual(c *C) { +func TestContextual(t *testing.T) { + t.Parallel() testCore, logs := observer.New(zap.InfoLevel) logutil.ResetGlobalLogger(zap.New(testCore)) @@ -244,15 +202,15 @@ func (s *testLoggingSuite) TestContextual(c *C) { l.Info("let's go!", zap.String("character", "solte")) observedLogs := logs.TakeAll() - checkLog(c, observedLogs[0], + checkLog(t, observedLogs[0], "going to take an adventure?", zap.Int("HP", 50), zap.Int("HP-MAX", 50), zap.String("character", "solte")) - checkLog(c, observedLogs[1], + checkLog(t, observedLogs[1], "let's go!", zap.Strings("friends", []string{"firo", "seren", "black"}), zap.String("character", "solte")) } -func checkLog(c *C, actual observer.LoggedEntry, message string, fields ...zap.Field) { - c.Assert(message, Equals, actual.Message) +func checkLog(t *testing.T, actual observer.LoggedEntry, message string, fields ...zap.Field) { + require.Equal(t, message, actual.Message) for i, f := range fields { - c.Assert(f, FieldEquals{}, actual.Context[i]) + require.Truef(t, f.Equals(actual.Context[i]), "Expected field(%+v) does not equal to actual one(%+v).", f, actual.Context[i]) } } diff --git a/br/pkg/pdutil/main_test.go b/br/pkg/pdutil/main_test.go new file mode 100644 index 0000000000000..861c3921a3eb3 --- /dev/null +++ b/br/pkg/pdutil/main_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package pdutil + +import ( + "testing" + + "github.com/pingcap/tidb/util/testbridge" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + testbridge.WorkaroundGoCheckFlags() + opts := []goleak.Option{ + goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"), + goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + } + goleak.VerifyTestMain(m, opts...) +} diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 5610578c0f766..2f898d9c062ef 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -156,14 +156,17 @@ func pdRequest( if count > pdRequestRetryTime || resp.StatusCode < 500 { break } - resp.Body.Close() - time.Sleep(time.Second) + _ = resp.Body.Close() + time.Sleep(pdRequestRetryInterval()) resp, err = cli.Do(req) if err != nil { return nil, errors.Trace(err) } } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() + if resp.StatusCode != http.StatusOK { res, _ := io.ReadAll(resp.Body) return nil, errors.Annotatef(berrors.ErrPDInvalidResponse, "[%d] %s %s", resp.StatusCode, res, reqURL) @@ -176,6 +179,15 @@ func pdRequest( return r, nil } +func pdRequestRetryInterval() time.Duration { + failpoint.Inject("FastRetry", func(v failpoint.Value) { + if v.(bool) { + failpoint.Return(0) + } + }) + return time.Second +} + // PdController manage get/update config from pd. type PdController struct { addrs []string diff --git a/br/pkg/pdutil/pd_test.go b/br/pkg/pdutil/pd_serial_test.go similarity index 72% rename from br/pkg/pdutil/pd_test.go rename to br/pkg/pdutil/pd_serial_test.go index e4e82d412171c..2dde535cd54b9 100644 --- a/br/pkg/pdutil/pd_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -15,26 +15,19 @@ import ( "testing" "github.com/coreos/go-semver/semver" - . 
"github.com/pingcap/check" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/util/codec" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/statistics" ) -func TestT(t *testing.T) { - TestingT(t) -} - -type testPDControllerSuite struct { -} - -var _ = Suite(&testPDControllerSuite{}) - -func (s *testPDControllerSuite) TestScheduler(c *C) { - ctx := context.Background() +func TestScheduler(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() scheduler := "balance-leader-scheduler" mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { @@ -44,13 +37,13 @@ func (s *testPDControllerSuite) TestScheduler(c *C) { pdController := &PdController{addrs: []string{"", ""}, schedulerPauseCh: schedulerPauseCh} _, err := pdController.pauseSchedulersAndConfigWith(ctx, []string{scheduler}, nil, mock) - c.Assert(err, ErrorMatches, "failed") + require.EqualError(t, err, "failed") go func() { <-schedulerPauseCh }() err = pdController.resumeSchedulerWith(ctx, []string{scheduler}, mock) - c.Assert(err, IsNil) + require.NoError(t, err) cfg := map[string]interface{}{ "max-merge-region-keys": 0, @@ -59,34 +52,37 @@ func (s *testPDControllerSuite) TestScheduler(c *C) { "max-pending-peer-count": uint64(16), } _, err = pdController.pauseSchedulersAndConfigWith(ctx, []string{}, cfg, mock) - c.Assert(err, ErrorMatches, "failed to update PD.*") + require.Error(t, err) + require.Regexp(t, "^failed to update PD.*", err.Error()) go func() { <-schedulerPauseCh }() + err = pdController.resumeSchedulerWith(ctx, []string{scheduler}, mock) + require.NoError(t, err) _, err = pdController.listSchedulersWith(ctx, mock) - c.Assert(err, ErrorMatches, "failed") + require.EqualError(t, err, "failed") mock = func(context.Context, string, string, *http.Client, string, 
io.Reader) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } _, err = pdController.pauseSchedulersAndConfigWith(ctx, []string{scheduler}, cfg, mock) - c.Assert(err, IsNil) + require.NoError(t, err) go func() { <-schedulerPauseCh }() err = pdController.resumeSchedulerWith(ctx, []string{scheduler}, mock) - c.Assert(err, IsNil) + require.NoError(t, err) schedulers, err := pdController.listSchedulersWith(ctx, mock) - c.Assert(err, IsNil) - c.Assert(schedulers, HasLen, 1) - c.Assert(schedulers[0], Equals, scheduler) + require.NoError(t, err) + require.Len(t, schedulers, 1) + require.Equal(t, scheduler, schedulers[0]) } -func (s *testPDControllerSuite) TestGetClusterVersion(c *C) { +func TestGetClusterVersion(t *testing.T) { pdController := &PdController{addrs: []string{"", ""}} // two endpoints counter := 0 mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { @@ -99,17 +95,17 @@ func (s *testPDControllerSuite) TestGetClusterVersion(c *C) { ctx := context.Background() respString, err := pdController.getClusterVersionWith(ctx, mock) - c.Assert(err, IsNil) - c.Assert(respString, Equals, "test") + require.NoError(t, err) + require.Equal(t, "test", respString) mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { return nil, errors.New("mock error") } _, err = pdController.getClusterVersionWith(ctx, mock) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testPDControllerSuite) TestRegionCount(c *C) { +func TestRegionCount(t *testing.T) { regions := core.NewRegionsInfo() regions.SetRegion(core.NewRegionInfo(&metapb.Region{ Id: 1, @@ -129,55 +125,61 @@ func (s *testPDControllerSuite) TestRegionCount(c *C) { EndKey: codec.EncodeBytes(nil, []byte{3, 4}), RegionEpoch: &metapb.RegionEpoch{}, }, nil)) - c.Assert(regions.Len(), Equals, 3) + require.Equal(t, 3, regions.Len()) mock := func( _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, ) 
([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) u, e := url.Parse(query) - c.Assert(e, IsNil, Commentf("%s", query)) + require.NoError(t, e, query) start := u.Query().Get("start_key") end := u.Query().Get("end_key") - c.Log(hex.EncodeToString([]byte(start))) - c.Log(hex.EncodeToString([]byte(end))) + t.Log(hex.EncodeToString([]byte(start))) + t.Log(hex.EncodeToString([]byte(end))) scanRegions := regions.ScanRange([]byte(start), []byte(end), 0) stats := statistics.RegionStats{Count: len(scanRegions)} ret, err := json.Marshal(stats) - c.Assert(err, IsNil) + require.NoError(t, err) return ret, nil } pdController := &PdController{addrs: []string{"http://mock"}} ctx := context.Background() resp, err := pdController.getRegionCountWith(ctx, mock, []byte{}, []byte{}) - c.Assert(err, IsNil) - c.Assert(resp, Equals, 3) + require.NoError(t, err) + require.Equal(t, 3, resp) resp, err = pdController.getRegionCountWith(ctx, mock, []byte{0}, []byte{0xff}) - c.Assert(err, IsNil) - c.Assert(resp, Equals, 3) + require.NoError(t, err) + require.Equal(t, 3, resp) resp, err = pdController.getRegionCountWith(ctx, mock, []byte{1, 2}, []byte{1, 4}) - c.Assert(err, IsNil) - c.Assert(resp, Equals, 2) + require.NoError(t, err) + require.Equal(t, 2, resp) } -func (s *testPDControllerSuite) TestPDVersion(c *C) { +func TestPDVersion(t *testing.T) { v := []byte("\"v4.1.0-alpha1\"\n") r := parseVersion(v) expectV := semver.New("4.1.0-alpha1") - c.Assert(r.Major, Equals, expectV.Major) - c.Assert(r.Minor, Equals, expectV.Minor) - c.Assert(r.PreRelease, Equals, expectV.PreRelease) + require.Equal(t, expectV.Major, r.Major) + require.Equal(t, expectV.Minor, r.Minor) + require.Equal(t, expectV.PreRelease, r.PreRelease) } -func (s *testPDControllerSuite) TestPDRequestRetry(c *C) { +func TestPDRequestRetry(t *testing.T) { ctx := context.Background() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/br/pkg/pdutil/FastRetry", "return(true)")) + defer func() { + 
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/pdutil/FastRetry")) + }() + count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ - if count <= 5 { + if count <= pdRequestRetryTime-1 { w.WriteHeader(http.StatusGatewayTimeout) return } @@ -186,12 +188,12 @@ func (s *testPDControllerSuite) TestPDRequestRetry(c *C) { cli := http.DefaultClient taddr := ts.URL _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) - c.Assert(reqErr, IsNil) + require.NoError(t, reqErr) ts.Close() count = 0 ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ - if count <= 11 { + if count <= pdRequestRetryTime+1 { w.WriteHeader(http.StatusGatewayTimeout) return } @@ -200,10 +202,10 @@ func (s *testPDControllerSuite) TestPDRequestRetry(c *C) { defer ts.Close() taddr = ts.URL _, reqErr = pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) - c.Assert(reqErr, NotNil) + require.Error(t, reqErr) } -func (s *testPDControllerSuite) TestStoreInfo(c *C) { +func TestStoreInfo(t *testing.T) { storeInfo := api.StoreInfo{ Status: &api.StoreStatus{ Capacity: typeutil.ByteSize(1024), @@ -217,18 +219,18 @@ func (s *testPDControllerSuite) TestStoreInfo(c *C) { _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) - c.Assert(query, Equals, "http://mock/pd/api/v1/store/1") + require.Equal(t, "http://mock/pd/api/v1/store/1", query) ret, err := json.Marshal(storeInfo) - c.Assert(err, IsNil) + require.NoError(t, err) return ret, nil } pdController := &PdController{addrs: []string{"http://mock"}} ctx := context.Background() resp, err := pdController.getStoreInfoWith(ctx, mock, 1) - c.Assert(err, IsNil) - c.Assert(resp, NotNil) - c.Assert(resp.Status, NotNil) - c.Assert(resp.Store.StateName, Equals, "Tombstone") - c.Assert(uint64(resp.Status.Available), Equals, uint64(1024)) + 
require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Status) + require.Equal(t, "Tombstone", resp.Store.StateName) + require.Equal(t, uint64(1024), uint64(resp.Status.Available)) } diff --git a/br/tests/lightning_checkpoint/config.toml b/br/tests/lightning_checkpoint/config.toml index 7d9a423e542af..e4595a6e0f045 100644 --- a/br/tests/lightning_checkpoint/config.toml +++ b/br/tests/lightning_checkpoint/config.toml @@ -5,7 +5,7 @@ table-concurrency = 1 enable = true schema = "tidb_lightning_checkpoint_test_cppk" driver = "mysql" -keep-after-success = true +keep-after-success = "origin" [mydumper] read-block-size = 1 diff --git a/br/tests/lightning_checkpoint/run.sh b/br/tests/lightning_checkpoint/run.sh index ed5f36a706912..41513ba575fc6 100755 --- a/br/tests/lightning_checkpoint/run.sh +++ b/br/tests/lightning_checkpoint/run.sh @@ -111,7 +111,7 @@ for BACKEND in importer local; do run_lightning -d "$DBPATH" --backend $BACKEND --enable-checkpoint=1 run_sql "$PARTIAL_IMPORT_QUERY" check_contains "s: $(( (1000 * $CHUNK_COUNT + 1001) * $CHUNK_COUNT * $TABLE_COUNT ))" - run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cppk.1357924680.bak`.table_v7 WHERE status >= 200' + run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cppk`.table_v7 WHERE status >= 200' check_contains "count(*): $TABLE_COUNT" # Ensure there is no dangling open engines diff --git a/br/tidb-lightning.toml b/br/tidb-lightning.toml index 2f26b6bcb2d8c..e28a5e5bbffc9 100644 --- a/br/tidb-lightning.toml +++ b/br/tidb-lightning.toml @@ -82,9 +82,12 @@ driver = "file" # For "mysql" driver, the DSN is a URL in the form "USER:PASS@tcp(HOST:PORT)/". # If not specified, the TiDB server from the [tidb] section will be used to store the checkpoints. #dsn = "/tmp/tidb_lightning_checkpoint.pb" -# Whether to keep the checkpoints after all data are imported. If false, the checkpoints will be deleted. The schema -# needs to be dropped manually, however. 
-#keep-after-success = false +# Whether to keep the checkpoints after all data are imported. +# valid options: +# - remove(default). the checkpoints will be deleted +# - rename. the checkpoints data will be kept, but will change the checkpoint data schema name with `schema.{taskID}.bak` +# - origin. keep the checkpoints data unchanged. +#keep-after-success = "remove" [tikv-importer] # Delivery backend, can be "importer", "local" or "tidb". diff --git a/ddl/column.go b/ddl/column.go index 0f87d240b4bc4..a459f2c8af0c7 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -1654,7 +1654,7 @@ func applyNewAutoRandomBits(d *ddlCtx, m *meta.Meta, dbInfo *model.DBInfo, if err != nil { return errors.Trace(err) } - err = autoRandAlloc.Rebase(nextAutoIncID, false) + err = autoRandAlloc.Rebase(context.Background(), nextAutoIncID, false) if err != nil { return errors.Trace(err) } diff --git a/ddl/db_test.go b/ddl/db_test.go index 2507a56f19549..3a5b5cb48a4e4 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -2862,58 +2862,6 @@ func (s *testSerialDBSuite) TestCreateTableWithLike2(c *C) { c.Assert(t1.Meta().TiFlashReplica.AvailablePartitionIDs, DeepEquals, []int64{partition.Definitions[0].ID, partition.Definitions[1].ID}) } -func (s *testSerialDBSuite) TestCreateTableWithSpecialComment(c *C) { - tk := testkit.NewTestKit(c, s.store) - tk.MustExec("use test") - - // case for direct options - tk.MustExec(`DROP TABLE IF EXISTS t`) - tk.MustExec("CREATE TABLE `t` (\n" + - " `a` int(11) DEFAULT NULL\n" + - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin " + - "/*T![placement] PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + - "FOLLOWERS=2 " + - "CONSTRAINTS=\"[+disk=ssd]\" */", - ) - tk.MustQuery(`show create table t`).Check(testutil.RowsWithSep("|", - "t CREATE TABLE `t` (\n"+ - " `a` int(11) DEFAULT NULL\n"+ - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin "+ - "/*T![placement] PRIMARY_REGION=\"cn-east-1\" "+ - "REGIONS=\"cn-east-1, 
cn-east-2\" "+ - "FOLLOWERS=2 "+ - "CONSTRAINTS=\"[+disk=ssd]\" */", - )) - - // case for policy - tk.MustExec(`DROP TABLE IF EXISTS t`) - tk.MustExec("create placement policy x " + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + - "FOLLOWERS=2 " + - "CONSTRAINTS=\"[+disk=ssd]\" ") - tk.MustExec("create table t(a int)" + - "/*T![placement] PLACEMENT POLICY=`x` */") - tk.MustQuery(`show create table t`).Check(testutil.RowsWithSep("|", - "t CREATE TABLE `t` (\n"+ - " `a` int(11) DEFAULT NULL\n"+ - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin "+ - "/*T![placement] PLACEMENT POLICY=`x` */", - )) - - // case for policy with quotes - tk.MustExec(`DROP TABLE IF EXISTS t`) - tk.MustExec("create table t(a int)" + - "/*T![placement] PLACEMENT POLICY=\"x\" */") - tk.MustQuery(`show create table t`).Check(testutil.RowsWithSep("|", - "t CREATE TABLE `t` (\n"+ - " `a` int(11) DEFAULT NULL\n"+ - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin "+ - "/*T![placement] PLACEMENT POLICY=`x` */", - )) -} - func (s *testSerialDBSuite) TestCreateTable(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 6715fd1a3b666..1474cc7d5a14a 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -2283,7 +2283,7 @@ func checkCharsetAndCollation(cs string, co string) error { func (d *ddl) handleAutoIncID(tbInfo *model.TableInfo, schemaID int64, newEnd int64, tp autoid.AllocatorType) error { allocs := autoid.NewAllocatorsFromTblInfo(d.store, schemaID, tbInfo) if alloc := allocs.Get(tp); alloc != nil { - err := alloc.Rebase(newEnd, false) + err := alloc.Rebase(context.Background(), newEnd, false) if err != nil { return errors.Trace(err) } @@ -2732,6 +2732,8 @@ func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, ident ast err = d.AlterTableAttributes(sctx, ident, spec) case ast.AlterTablePartitionAttributes: err = d.AlterTablePartitionAttributes(sctx, ident, spec) + 
case ast.AlterTablePartitionOptions: + err = d.AlterTablePartitionOptions(sctx, ident, spec) default: // Nothing to do now. } @@ -6189,10 +6191,9 @@ func (d *ddl) AlterTableAlterPartition(ctx sessionctx.Context, ident ast.Ident, return errors.Trace(err) } + // TODO: the old placement rules should be migrated to new format. use the bundle from meta directly. bundle := infoschema.GetBundle(d.infoCache.GetLatest(), []int64{partitionID, meta.ID, schema.ID}) - bundle.ID = placement.GroupID(partitionID) - err = bundle.ApplyPlacementSpec(spec.PlacementSpecs) if err != nil { var sb strings.Builder @@ -6210,15 +6211,7 @@ func (d *ddl) AlterTableAlterPartition(ctx sessionctx.Context, ident ast.Ident, if err != nil { return errors.Trace(err) } - bundle.Reset(partitionID) - - if len(bundle.Rules) == 0 { - bundle.Index = 0 - bundle.Override = false - } else { - bundle.Index = placement.RuleIndexPartition - bundle.Override = true - } + bundle.Reset(placement.RuleIndexPartition, []int64{partitionID}) job := &model.Job{ SchemaID: schema.ID, @@ -6324,6 +6317,81 @@ func (d *ddl) AlterTablePartitionAttributes(ctx sessionctx.Context, ident ast.Id return errors.Trace(err) } +func (d *ddl) AlterTablePartitionOptions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { + schema, tb, err := d.getSchemaAndTableByIdent(ctx, ident) + if err != nil { + return errors.Trace(err) + } + + meta := tb.Meta() + if meta.Partition == nil { + return errors.Trace(ErrPartitionMgmtOnNonpartitioned) + } + + partitionID, err := tables.FindPartitionByName(meta, spec.PartitionNames[0].L) + if err != nil { + return errors.Trace(err) + } + var policyRefInfo *model.PolicyRefInfo + var placementSettings *model.PlacementSettings + if spec.Options != nil { + for _, op := range spec.Options { + switch op.Tp { + case ast.TableOptionPlacementPolicy: + policyRefInfo = &model.PolicyRefInfo{ + Name: model.NewCIStr(op.StrValue), + } + case ast.TableOptionPlacementPrimaryRegion, 
ast.TableOptionPlacementRegions, + ast.TableOptionPlacementFollowerCount, ast.TableOptionPlacementVoterCount, + ast.TableOptionPlacementLearnerCount, ast.TableOptionPlacementSchedule, + ast.TableOptionPlacementConstraints, ast.TableOptionPlacementLeaderConstraints, + ast.TableOptionPlacementLearnerConstraints, ast.TableOptionPlacementFollowerConstraints, + ast.TableOptionPlacementVoterConstraints: + if placementSettings == nil { + placementSettings = &model.PlacementSettings{} + } + err = SetDirectPlacementOpt(placementSettings, ast.PlacementOptionType(op.Tp), op.StrValue, op.UintValue) + if err != nil { + return err + } + default: + return errors.Trace(errors.New("unknown partition option")) + } + } + } + + // Can not use both a placement policy and direct assignment. If you specify both in a CREATE TABLE or ALTER TABLE an error will be returned. + if placementSettings != nil && policyRefInfo != nil { + return errors.Trace(ErrPlacementPolicyWithDirectOption.GenWithStackByArgs(policyRefInfo.Name)) + } + if placementSettings != nil { + // check the direct placement option compatibility. 
+ if err := checkPolicyValidation(placementSettings); err != nil { + return errors.Trace(err) + } + } + if policyRefInfo != nil { + policy, ok := ctx.GetInfoSchema().(infoschema.InfoSchema).PolicyByName(policyRefInfo.Name) + if !ok { + return errors.Trace(infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(policyRefInfo.Name)) + } + policyRefInfo.ID = policy.ID + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + Type: model.ActionAlterTablePartitionPolicy, + BinlogInfo: &model.HistoryInfo{}, + Args: []interface{}{partitionID, policyRefInfo, placementSettings}, + } + + err = d.doDDLJob(ctx, job) + err = d.callHookOnChanged(err) + return errors.Trace(err) +} + func buildPolicyInfo(name model.CIStr, options []*ast.PlacementOption) (*model.PolicyInfo, error) { policyInfo := &model.PolicyInfo{PlacementSettings: &model.PlacementSettings{}} policyInfo.Name = name diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index a7c253b5ddbdf..004b09397374e 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -832,6 +832,8 @@ func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, ver, err = onDropPlacementPolicy(d, t, job) case model.ActionAlterPlacementPolicy: ver, err = onAlterPlacementPolicy(d, t, job) + case model.ActionAlterTablePartitionPolicy: + ver, err = onAlterTablePartitionOptions(t, job) default: // Invalid job, cancel it. job.State = model.JobStateCancelled diff --git a/ddl/partition.go b/ddl/partition.go index 7b581ec1c937b..7dfb4dfd9f379 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -487,13 +487,44 @@ func buildRangePartitionDefinitions(ctx sessionctx.Context, defs []*ast.Partitio } } comment, _ := def.Comment() + var directPlacementOpts *model.PlacementSettings + var placementPolicyRef *model.PolicyRefInfo + // the partition inheritance of placement rules don't have to copy the placement elements to themselves. 
+ // For example: + // t placement policy x (p1 placement policy y, p2) + // p2 will share the same rule as table t does, but it won't copy the meta to itself. we will + // append p2 range to the coverage of table t's rules. This mechanism is good for cascading change + // when policy x is altered. + for _, opt := range def.Options { + switch opt.Tp { + case ast.TableOptionPlacementPrimaryRegion, ast.TableOptionPlacementRegions, + ast.TableOptionPlacementFollowerCount, ast.TableOptionPlacementVoterCount, + ast.TableOptionPlacementLearnerCount, ast.TableOptionPlacementSchedule, + ast.TableOptionPlacementConstraints, ast.TableOptionPlacementLeaderConstraints, + ast.TableOptionPlacementLearnerConstraints, ast.TableOptionPlacementFollowerConstraints, + ast.TableOptionPlacementVoterConstraints: + if directPlacementOpts == nil { + directPlacementOpts = &model.PlacementSettings{} + } + err := SetDirectPlacementOpt(directPlacementOpts, ast.PlacementOptionType(opt.Tp), opt.StrValue, opt.UintValue) + if err != nil { + return nil, err + } + case ast.TableOptionPlacementPolicy: + placementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(opt.StrValue), + } + } + } err := checkTooLongTable(def.Name) if err != nil { return nil, err } piDef := model.PartitionDefinition{ - Name: def.Name, - Comment: comment, + Name: def.Name, + Comment: comment, + DirectPlacementOpts: directPlacementOpts, + PlacementPolicyRef: placementPolicyRef, } buf := new(bytes.Buffer) @@ -1124,7 +1155,7 @@ func onTruncateTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, e oldBundle, ok := d.infoCache.GetLatest().BundleByName(placement.GroupID(oldID)) if ok && !oldBundle.IsEmpty() { bundles = append(bundles, placement.NewBundle(oldID)) - bundles = append(bundles, oldBundle.Clone().Reset(newPartitions[i].ID)) + bundles = append(bundles, oldBundle.Clone().Reset(placement.RuleIndexPartition, []int64{newPartitions[i].ID})) } } @@ -1331,14 +1362,14 @@ func (w *worker) onExchangeTablePartition(d 
*ddlCtx, t *meta.Meta, job *model.Jo ntBundle, ntOK := d.infoCache.GetLatest().BundleByName(placement.GroupID(nt.ID)) ntOK = ntOK && !ntBundle.IsEmpty() if ptOK && ntOK { - bundles = append(bundles, ptBundle.Clone().Reset(nt.ID)) - bundles = append(bundles, ntBundle.Clone().Reset(partDef.ID)) + bundles = append(bundles, ptBundle.Clone().Reset(placement.RuleIndexPartition, []int64{nt.ID})) + bundles = append(bundles, ntBundle.Clone().Reset(placement.RuleIndexPartition, []int64{partDef.ID})) } else if ptOK { bundles = append(bundles, placement.NewBundle(partDef.ID)) - bundles = append(bundles, ptBundle.Clone().Reset(nt.ID)) + bundles = append(bundles, ptBundle.Clone().Reset(placement.RuleIndexPartition, []int64{nt.ID})) } else if ntOK { bundles = append(bundles, placement.NewBundle(nt.ID)) - bundles = append(bundles, ntBundle.Clone().Reset(partDef.ID)) + bundles = append(bundles, ntBundle.Clone().Reset(placement.RuleIndexPartition, []int64{partDef.ID})) } err = infosync.PutRuleBundles(context.TODO(), bundles) if err != nil { diff --git a/ddl/placement/bundle.go b/ddl/placement/bundle.go index 3994e04ad3ca3..e132d21242c02 100644 --- a/ddl/placement/bundle.go +++ b/ddl/placement/bundle.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "math" + "sort" "strconv" "strings" @@ -65,9 +66,7 @@ func NewBundleFromConstraintsOptions(options *model.PlacementSettings) (*Bundle, leaderConstraints := options.LeaderConstraints learnerConstraints := options.LearnerConstraints followerConstraints := options.FollowerConstraints - voterConstraints := options.VoterConstraints followerCount := options.Followers - voterCount := options.Voters learnerCount := options.Learners CommonConstraints, err := NewConstraintsFromYaml([]byte(constraints)) @@ -90,21 +89,6 @@ func NewBundleFromConstraintsOptions(options *model.PlacementSettings) (*Bundle, Rules = append(Rules, NewRule(Leader, 1, LeaderConstraints)) } - if voterCount > 0 { - VoterRules, err := NewRules(Voter, voterCount, voterConstraints) - 
if err != nil { - return nil, fmt.Errorf("%w: invalid VoterConstraints", err) - } - for _, rule := range VoterRules { - for _, cnst := range CommonConstraints { - if err := rule.Constraints.Add(cnst); err != nil { - return nil, fmt.Errorf("%w: VoterConstraints conflicts with Constraints", err) - } - } - } - Rules = append(Rules, VoterRules...) - } - if followerCount > 0 { FollowerRules, err := NewRules(Follower, followerCount, followerConstraints) if err != nil { @@ -144,7 +128,7 @@ func NewBundleFromSugarOptions(options *model.PlacementSettings) (*Bundle, error return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) } - if len(options.LeaderConstraints) > 0 || len(options.LearnerConstraints) > 0 || len(options.FollowerConstraints) > 0 || len(options.VoterConstraints) > 0 || options.Learners > 0 || options.Voters > 0 { + if len(options.LeaderConstraints) > 0 || len(options.LearnerConstraints) > 0 || len(options.FollowerConstraints) > 0 || len(options.Constraints) > 0 || options.Learners > 0 { return nil, fmt.Errorf("%w: should be PRIMARY_REGION=.. REGIONS=.. FOLLOWERS=.. 
SCHEDULE=.., mixed other constraints into options %s", ErrInvalidPlacementOptions, options) } @@ -167,65 +151,45 @@ func NewBundleFromSugarOptions(options *model.PlacementSettings) (*Bundle, error } schedule := options.Schedule - var constraints Constraints - var err error + // regions must include the primary + sort.Strings(regions) + primaryIndex := sort.SearchStrings(regions, primaryRegion) + if primaryIndex >= len(regions) || regions[primaryIndex] != primaryRegion { + return nil, fmt.Errorf("%w: primary region must be included in regions", ErrInvalidPlacementOptions) + } + + var Rules []*Rule - Rules := []*Rule{} switch strings.ToLower(schedule) { case "", "even": - constraints, err = NewConstraints([]string{fmt.Sprintf("+region=%s", primaryRegion)}) - if err != nil { - return nil, fmt.Errorf("%w: invalid PrimaryRegion '%s'", err, primaryRegion) + primaryCount := uint64(math.Ceil(float64(followers+1) / float64(len(regions)))) + Rules = append(Rules, NewRule(Voter, primaryCount, NewConstraintsDirect(NewConstraintDirect("region", In, primaryRegion)))) + + if len(regions) > 1 { + // delete primary from regions + regions = regions[:primaryIndex+copy(regions[primaryIndex:], regions[primaryIndex+1:])] + Rules = append(Rules, NewRule(Follower, followers+1-primaryCount, NewConstraintsDirect(NewConstraintDirect("region", In, regions...)))) } - Rules = append(Rules, NewRule(Leader, 1, constraints)) case "majority_in_primary": - // We already have the leader, so we need to calculate how many additional followers - // need to be in the primary region for quorum - followersInPrimary := uint64(math.Ceil(float64(followers) / 2)) - constraints, err = NewConstraints([]string{fmt.Sprintf("+region=%s", primaryRegion)}) - if err != nil { - return nil, fmt.Errorf("%w: invalid PrimaryRegion, '%s'", err, primaryRegion) + // calculate how many replicas need to be in the primary region for quorum + primaryCount := uint64(math.Ceil(float64(followers+1)/2 + 1)) + Rules = append(Rules, 
NewRule(Voter, primaryCount, NewConstraintsDirect(NewConstraintDirect("region", In, primaryRegion)))) + + if len(regions) > 1 { + // delete primary from regions + regions = regions[:primaryIndex+copy(regions[primaryIndex:], regions[primaryIndex+1:])] + Rules = append(Rules, NewRule(Follower, followers+1-primaryCount, NewConstraintsDirect(NewConstraintDirect("region", In, regions...)))) } - Rules = append(Rules, NewRule(Leader, 1, constraints)) - Rules = append(Rules, NewRule(Follower, followersInPrimary, constraints)) - // even split the remaining followers - followers = followers - followersInPrimary default: return nil, fmt.Errorf("%w: unsupported schedule %s", ErrInvalidPlacementOptions, schedule) } - if uint64(len(regions)) > followers { - return nil, fmt.Errorf("%w: remain %d region to schedule, only %d follower left", ErrInvalidPlacementOptions, uint64(len(regions)), followers) - } - - if len(regions) == 0 { - constraints, err := NewConstraints(nil) - if err != nil { - return nil, err - } - Rules = append(Rules, NewRule(Follower, followers, constraints)) - } else { - count := followers / uint64(len(regions)) - rem := followers - count*uint64(len(regions)) - for _, region := range regions { - constraints, err = NewConstraints([]string{fmt.Sprintf("+region=%s", region)}) - if err != nil { - return nil, fmt.Errorf("%w: invalid region of 'Regions', '%s'", err, region) - } - replica := count - if rem > 0 { - replica += 1 - rem-- - } - Rules = append(Rules, NewRule(Follower, replica, constraints)) - } - } - return &Bundle{Rules: Rules}, nil } -// NewBundleFromOptions will transform options into the bundle. -func NewBundleFromOptions(options *model.PlacementSettings) (*Bundle, error) { +// Non-Exported functionality function, do not use it directly but NewBundleFromOptions +// here is for only directly used in the test. 
+func newBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err error) { var isSyntaxSugar bool if options == nil { @@ -237,9 +201,27 @@ func NewBundleFromOptions(options *model.PlacementSettings) (*Bundle, error) { } if isSyntaxSugar { - return NewBundleFromSugarOptions(options) + bundle, err = NewBundleFromSugarOptions(options) + } else { + bundle, err = NewBundleFromConstraintsOptions(options) + } + return bundle, err +} + +// NewBundleFromOptions will transform options into the bundle. +func NewBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err error) { + bundle, err = newBundleFromOptions(options) + if err != nil { + return nil, err + } + if bundle == nil { + return nil, nil + } + err = bundle.Tidy() + if err != nil { + return nil, err } - return NewBundleFromConstraintsOptions(options) + return bundle, err } // ApplyPlacementSpec will apply actions defined in PlacementSpec to the bundle. @@ -363,16 +345,57 @@ func (b *Bundle) Tidy() error { } // Reset resets the bundle ID and keyrange of all rules. -func (b *Bundle) Reset(newID int64) *Bundle { - b.ID = GroupID(newID) - // Involve all the table level objects. - startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID))) - endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID+1))) - for _, rule := range b.Rules { - rule.GroupID = b.ID - rule.StartKeyHex = startKey - rule.EndKeyHex = endKey +func (b *Bundle) Reset(ruleIndex int, newIDs []int64) *Bundle { + // eliminate the redundant rules. + var basicRules []*Rule + if len(b.Rules) != 0 { + // Make priority for rules with RuleIndexTable cause of duplication rules existence with RuleIndexPartition. + // If RuleIndexTable doesn't exist, bundle itself is a independent series of rules for a partition. 
+ for _, rule := range b.Rules { + if rule.Index == RuleIndexTable { + basicRules = append(basicRules, rule) + } + } + if len(basicRules) == 0 { + basicRules = b.Rules + } } + + // extend and reset basic rules for all new ids, the first id should be the group id. + b.ID = GroupID(newIDs[0]) + b.Index = ruleIndex + b.Override = true + newRules := make([]*Rule, 0, len(basicRules)*len(newIDs)) + for i, newID := range newIDs { + // rule.id should be distinguished with each other, otherwise it will be de-duplicated in pd http api. + var ruleID string + if ruleIndex == RuleIndexPartition { + ruleID = "partition_rule_" + strconv.FormatInt(newID, 10) + } else { + if i == 0 { + ruleID = "table_rule_" + strconv.FormatInt(newID, 10) + } else { + ruleID = "partition_rule_" + strconv.FormatInt(newID, 10) + } + } + // Involve all the table level objects. + startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID))) + endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID+1))) + for _, rule := range basicRules { + clone := rule.Clone() + clone.ID = ruleID + clone.GroupID = b.ID + clone.StartKeyHex = startKey + clone.EndKeyHex = endKey + if i == 0 { + clone.Index = RuleIndexTable + } else { + clone.Index = RuleIndexPartition + } + newRules = append(newRules, clone) + } + } + b.Rules = newRules return b } diff --git a/ddl/placement/bundle_test.go b/ddl/placement/bundle_test.go index 0ddde1816f708..f11b0ac0130f2 100644 --- a/ddl/placement/bundle_test.go +++ b/ddl/placement/bundle_test.go @@ -535,7 +535,7 @@ func (s *testBundleSuite) TestString(c *C) { c.Assert(err, IsNil) bundle.Rules = append(rules1, rules2...) 
- c.Assert(bundle.String(), Equals, `{"group_id":"TiDB_DDL_1","group_index":0,"group_override":false,"rules":[{"group_id":"","id":"","start_key":"","end_key":"","role":"voter","count":3,"label_constraints":[{"key":"zone","op":"in","values":["sh"]}]},{"group_id":"","id":"","start_key":"","end_key":"","role":"voter","count":4,"label_constraints":[{"key":"zone","op":"notIn","values":["sh"]},{"key":"zone","op":"in","values":["bj"]}]}]}`) + c.Assert(bundle.String(), Equals, `{"group_id":"TiDB_DDL_1","group_index":0,"group_override":false,"rules":[{"group_id":"","id":"","start_key":"","end_key":"","role":"voter","count":3,"label_constraints":[{"key":"zone","op":"in","values":["sh"]}],"location_labels":["region","zone","rack","host"],"isolation_level":"region"},{"group_id":"","id":"","start_key":"","end_key":"","role":"voter","count":4,"label_constraints":[{"key":"zone","op":"notIn","values":["sh"]},{"key":"zone","op":"in","values":["bj"]}],"location_labels":["region","zone","rack","host"],"isolation_level":"region"}]}`) c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/placement/MockMarshalFailure", `return(true)`), IsNil) defer func() { @@ -580,40 +580,36 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { name: "sugar syntax: normal case 1", input: &model.PlacementSettings{ PrimaryRegion: "us", + Regions: "us", }, output: []*Rule{ - { - Role: Leader, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 1, - }, - { - Role: Follower, - Constraints: Constraints{}, - Count: 2, - }, + NewRule(Voter, 3, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), }, }) tests = append(tests, TestCase{ - name: "sugar syntax: invalid followers", + name: "sugar syntax: few followers", input: &model.PlacementSettings{ PrimaryRegion: "us", - Regions: "us,sh,bj", + Regions: "bj,sh,us", + }, + output: []*Rule{ + NewRule(Voter, 1, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), + 
NewRule(Follower, 2, NewConstraintsDirect( + NewConstraintDirect("region", In, "bj", "sh"), + )), }, - err: ErrInvalidPlacementOptions, }) tests = append(tests, TestCase{ name: "sugar syntax: wrong schedule prop", input: &model.PlacementSettings{ PrimaryRegion: "us", + Regions: "us", Schedule: "wrong", }, err: ErrInvalidPlacementOptions, @@ -623,8 +619,9 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { name: "sugar syntax: invalid region name 1", input: &model.PlacementSettings{ PrimaryRegion: ",=,", + Regions: ",=,", }, - err: ErrInvalidConstraintFormat, + err: ErrInvalidPlacementOptions, }) tests = append(tests, TestCase{ @@ -633,68 +630,32 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { PrimaryRegion: "f", Regions: ",=", }, - err: ErrInvalidConstraintFormat, - }) - - tests = append(tests, TestCase{ - name: "sugar syntax: invalid region name 3", - input: &model.PlacementSettings{ - PrimaryRegion: ",=", - Followers: 5, - Schedule: "majority_in_primary", - }, - err: ErrInvalidConstraintFormat, + err: ErrInvalidPlacementOptions, }) tests = append(tests, TestCase{ name: "sugar syntax: invalid region name 4", input: &model.PlacementSettings{ PrimaryRegion: "", + Regions: "g", }, - output: []*Rule{}, + err: ErrInvalidPlacementOptions, }) tests = append(tests, TestCase{ name: "sugar syntax: normal case 2", input: &model.PlacementSettings{ PrimaryRegion: "us", - Regions: "us,sh", + Regions: "sh,us", Followers: 5, }, output: []*Rule{ - { - Role: Leader, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 1, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 3, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"sh"}, - }, - }, - Count: 2, - }, + NewRule(Voter, 3, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), + NewRule(Follower, 3, 
NewConstraintsDirect( + NewConstraintDirect("region", In, "sh"), + )), }, }) tests = append(tests, tests[len(tests)-1]) @@ -704,56 +665,18 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { tests = append(tests, TestCase{ name: "sugar syntax: majority schedule", input: &model.PlacementSettings{ - PrimaryRegion: "us", + PrimaryRegion: "sh", Regions: "bj,sh", Followers: 5, Schedule: "majority_in_primary", }, output: []*Rule{ - { - Role: Leader, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 1, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 3, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"bj"}, - }, - }, - Count: 1, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"sh"}, - }, - }, - Count: 1, - }, + NewRule(Voter, 4, NewConstraintsDirect( + NewConstraintDirect("region", In, "sh"), + )), + NewRule(Follower, 2, NewConstraintsDirect( + NewConstraintDirect("region", In, "bj"), + )), }, }) @@ -764,60 +687,12 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { Followers: 2, }, output: []*Rule{ - { - Role: Leader, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 1, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 2, - }, - }, - }) - - tests = append(tests, TestCase{ - name: "direct syntax: normal case 2", - input: &model.PlacementSettings{ - Constraints: "[+region=us]", - Voters: 2, - }, - output: []*Rule{ - { - Role: Leader, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 1, - }, - { - Role: Voter, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 2, - }, + 
NewRule(Leader, 1, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), + NewRule(Follower, 2, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), }, }) @@ -829,39 +704,15 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { Learners: 2, }, output: []*Rule{ - { - Role: Leader, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 1, - }, - { - Role: Follower, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 2, - }, - { - Role: Learner, - Constraints: Constraints{ - { - Key: "region", - Op: In, - Values: []string{"us"}, - }, - }, - Count: 2, - }, + NewRule(Leader, 1, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), + NewRule(Follower, 2, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), + NewRule(Learner, 2, NewConstraintsDirect( + NewConstraintDirect("region", In, "us"), + )), }, }) @@ -875,16 +726,6 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { err: ErrConflictingConstraints, }) - tests = append(tests, TestCase{ - name: "direct syntax: conflicts 2", - input: &model.PlacementSettings{ - Constraints: "[+region=us]", - VoterConstraints: "[-region=us]", - Voters: 2, - }, - err: ErrConflictingConstraints, - }) - tests = append(tests, TestCase{ name: "direct syntax: conflicts 3", input: &model.PlacementSettings{ @@ -926,16 +767,6 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { err: ErrInvalidConstraintsFormat, }) - tests = append(tests, TestCase{ - name: "direct syntax: invalid format 3", - input: &model.PlacementSettings{ - Constraints: "[+region=us]", - VoterConstraints: "-region=us]", - Voters: 2, - }, - err: ErrInvalidConstraintsFormat, - }) - tests = append(tests, TestCase{ name: "direct syntax: invalid format 4", input: &model.PlacementSettings{ @@ -958,7 +789,7 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { }) for _, t := range 
tests { - bundle, err := NewBundleFromOptions(t.input) + bundle, err := newBundleFromOptions(t.input) comment := Commentf("[%s]\nerr1 %s\nerr2 %s", t.name, err, t.err) if t.err != nil { c.Assert(errors.Is(err, t.err), IsTrue, comment) @@ -969,7 +800,7 @@ func (s *testBundleSuite) TestNewBundleFromOptions(c *C) { } } -func (s *testBundleSuite) TestReset(c *C) { +func (s *testBundleSuite) TestResetBundleWithSingleRule(c *C) { bundle := &Bundle{ ID: GroupID(1), } @@ -978,8 +809,10 @@ func (s *testBundleSuite) TestReset(c *C) { c.Assert(err, IsNil) bundle.Rules = rules - bundle.Reset(3) + bundle.Reset(RuleIndexTable, []int64{3}) c.Assert(bundle.ID, Equals, GroupID(3)) + c.Assert(bundle.Override, Equals, true) + c.Assert(bundle.Index, Equals, RuleIndexTable) c.Assert(bundle.Rules, HasLen, 1) c.Assert(bundle.Rules[0].GroupID, Equals, bundle.ID) @@ -990,6 +823,100 @@ func (s *testBundleSuite) TestReset(c *C) { c.Assert(bundle.Rules[0].EndKeyHex, Equals, endKey) } +func (s *testBundleSuite) TestResetBundleWithMultiRules(c *C) { + // build a bundle with three rules. + bundle, err := NewBundleFromOptions(&model.PlacementSettings{ + LeaderConstraints: `["+zone=bj"]`, + Followers: 2, + FollowerConstraints: `["+zone=hz"]`, + Learners: 1, + LearnerConstraints: `["+zone=cd"]`, + Constraints: `["+disk=ssd"]`, + }) + c.Assert(err, IsNil) + c.Assert(len(bundle.Rules), Equals, 3) + + // test if all the three rules are basic rules even the start key are not set. + bundle.Reset(RuleIndexTable, []int64{1, 2, 3}) + c.Assert(bundle.ID, Equals, GroupID(1)) + c.Assert(bundle.Index, Equals, RuleIndexTable) + c.Assert(bundle.Override, Equals, true) + c.Assert(len(bundle.Rules), Equals, 3*3) + // for id 1. 
+ startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(1))) + endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(2))) + c.Assert(bundle.Rules[0].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[0].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[1].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[1].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[2].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[2].EndKeyHex, Equals, endKey) + // for id 2. + startKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(2))) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(3))) + c.Assert(bundle.Rules[3].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[3].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[4].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[4].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[5].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[5].EndKeyHex, Equals, endKey) + // for id 3. + startKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(3))) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(4))) + c.Assert(bundle.Rules[6].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[6].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[7].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[7].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[8].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[8].EndKeyHex, Equals, endKey) + + // test if bundle has redundant rules. + // for now, the bundle has 9 rules, each table id or partition id has the three with them. + // once we reset this bundle for another ids, for example, adding partitions. we should + // extend the basic rules(3 of them) to the new partition id. 
+ bundle.Reset(RuleIndexTable, []int64{1, 3, 4, 5}) + c.Assert(bundle.ID, Equals, GroupID(1)) + c.Assert(bundle.Index, Equals, RuleIndexTable) + c.Assert(bundle.Override, Equals, true) + c.Assert(len(bundle.Rules), Equals, 3*4) + // for id 1. + startKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(1))) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(2))) + c.Assert(bundle.Rules[0].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[0].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[1].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[1].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[2].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[2].EndKeyHex, Equals, endKey) + // for id 3. + startKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(3))) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(4))) + c.Assert(bundle.Rules[3].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[3].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[4].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[4].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[5].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[5].EndKeyHex, Equals, endKey) + // for id 4. + startKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(4))) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(5))) + c.Assert(bundle.Rules[6].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[6].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[7].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[7].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[8].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[8].EndKeyHex, Equals, endKey) + // for id 5. 
+ startKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(5))) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(6))) + c.Assert(bundle.Rules[9].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[9].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[10].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[10].EndKeyHex, Equals, endKey) + c.Assert(bundle.Rules[11].StartKeyHex, Equals, startKey) + c.Assert(bundle.Rules[11].EndKeyHex, Equals, endKey) +} + func (s *testBundleSuite) TestTidy(c *C) { bundle := &Bundle{ ID: GroupID(1), diff --git a/ddl/placement/constraint.go b/ddl/placement/constraint.go index 3263f104dd668..49970eb31570d 100644 --- a/ddl/placement/constraint.go +++ b/ddl/placement/constraint.go @@ -81,10 +81,19 @@ func NewConstraint(label string) (Constraint, error) { r.Key = key r.Op = op - r.Values = []string{val} + r.Values = strings.Split(val, ",") return r, nil } +// NewConstraintDirect will create a Constraint from argument directly. +func NewConstraintDirect(key string, op ConstraintOp, val ...string) Constraint { + return Constraint{ + Key: key, + Op: op, + Values: val, + } +} + // Restore converts a Constraint to a string. func (c *Constraint) Restore() (string, error) { var sb strings.Builder diff --git a/ddl/placement/constraints.go b/ddl/placement/constraints.go index fa4dbcf02a613..87619a8df32df 100644 --- a/ddl/placement/constraints.go +++ b/ddl/placement/constraints.go @@ -26,6 +26,10 @@ type Constraints []Constraint // NewConstraints will check each labels, and build the Constraints. 
func NewConstraints(labels []string) (Constraints, error) { + if len(labels) == 0 { + return nil, nil + } + constraints := make(Constraints, 0, len(labels)) for _, str := range labels { label, err := NewConstraint(strings.TrimSpace(str)) @@ -52,6 +56,11 @@ func NewConstraintsFromYaml(c []byte) (Constraints, error) { return NewConstraints(constraints) } +// NewConstraintsDirect is a helper for creating new constraints from individual constraint. +func NewConstraintsDirect(c ...Constraint) Constraints { + return c +} + // Restore converts label constraints to a string. func (constraints *Constraints) Restore() (string, error) { var sb strings.Builder diff --git a/ddl/placement/rule.go b/ddl/placement/rule.go index cf94eeba53ca1..216714789aec9 100644 --- a/ddl/placement/rule.go +++ b/ddl/placement/rule.go @@ -50,102 +50,16 @@ type Rule struct { IsolationLevel string `json:"isolation_level,omitempty"` } -type constraintCombineType uint8 - -const ( - listAndList constraintCombineType = 0x0 - dictAndList constraintCombineType = 0x1 -) - -// NewMergeRules constructs []*Rule from a yaml-compatible representation of array or map of constraint1 and constraint2. -// It is quite like NewRules but the common constraint2 can only be a list labels which will be appended to constraint1. -func NewMergeRules(replicas uint64, constr1, constr2 string) ([]*Rule, error) { - var ( - err1, err2, err3 error - combineType constraintCombineType - ) - rules := []*Rule{} - constraintsList1, constraintsList2 := []string{}, []string{} - constraintsDict1 := map[string]int{} - err1, err2 = yaml.UnmarshalStrict([]byte(constr1), &constraintsList1), yaml.UnmarshalStrict([]byte(constr2), &constraintsList2) - if err2 != nil { - // Common constraints can only be a list. - return nil, fmt.Errorf("%w: should be [constraint1, ...] 
(error %s)", ErrInvalidConstraintsFormat, err2) - } - if err1 != nil { - combineType = 0x01 - err3 = yaml.UnmarshalStrict([]byte(constr1), &constraintsDict1) - if err3 != nil { - return nil, fmt.Errorf("%w: should be [constraint1, ...] (error %s), {constraint1: cnt1, ...} (error %s), or any yaml compatible representation", ErrInvalidConstraintsFormat, err1, err3) - } - } - switch combineType { - case listAndList: - /* - * eg: followerConstraint = ["+zone=sh", "+zone=bj"], constraint = ["+disk=ssd"] - * res: followerConstraint = ["+zone=sh", "+zone=bj", "+disk=ssd"] - */ - if replicas == 0 { - if len(constr1) == 0 { - return rules, nil - } - return rules, fmt.Errorf("%w: should be positive", ErrInvalidConstraintsRelicas) - } - constraintsList1 = append(constraintsList1, constraintsList2...) - labelConstraints, err := NewConstraints(constraintsList1) - if err != nil { - return rules, err - } - rules = append(rules, &Rule{ - Count: int(replicas), - Constraints: labelConstraints, - }) - return rules, nil - case dictAndList: - /* - * eg: followerConstraint = { '+zone=sh, -zone=bj':2, '+zone=sh':1 }, constraint = ['+disk=ssd'] - * res: followerConstraint = { '+zone=sh, -zone=bj, +disk=ssd':2, '+zone=sh, +disk=ssd':1 } - */ - ruleCnt := 0 - for labels, cnt := range constraintsDict1 { - if cnt <= 0 { - return rules, fmt.Errorf("%w: count of labels '%s' should be positive, but got %d", ErrInvalidConstraintsMapcnt, labels, cnt) - } - ruleCnt += cnt - } - if replicas == 0 { - replicas = uint64(ruleCnt) - } - if int(replicas) < ruleCnt { - return rules, fmt.Errorf("%w: should be larger or equal to the number of total replicas, but REPLICAS=%d < total=%d", ErrInvalidConstraintsRelicas, replicas, ruleCnt) - } - for labels, cnt := range constraintsDict1 { - mergeLabels := append(strings.Split(labels, ","), constraintsList2...) 
- labelConstraints, err := NewConstraints(mergeLabels) - if err != nil { - return rules, err - } - rules = append(rules, &Rule{ - Count: cnt, - Constraints: labelConstraints, - }) - } - remain := int(replicas) - ruleCnt - if remain > 0 { - rules = append(rules, &Rule{ - Count: remain, - }) - } - return rules, nil - } - // empty - return rules, nil -} - // NewRule constructs *Rule from role, count, and constraints. It is here to // consistent the behavior of creating new rules. func NewRule(role PeerRoleType, replicas uint64, cnst Constraints) *Rule { - return &Rule{Role: role, Count: int(replicas), Constraints: cnst} + return &Rule{ + Role: role, + Count: int(replicas), + Constraints: cnst, + LocationLabels: []string{"region", "zone", "rack", "host"}, + IsolationLevel: "region", + } } // NewRules constructs []*Rule from a yaml-compatible representation of @@ -213,3 +127,7 @@ func (r *Rule) Clone() *Rule { *n = *r return n } + +func (r *Rule) String() string { + return fmt.Sprintf("%+v", *r) +} diff --git a/ddl/placement/rule_test.go b/ddl/placement/rule_test.go index 5abebd927a8a2..9432448127a4a 100644 --- a/ddl/placement/rule_test.go +++ b/ddl/placement/rule_test.go @@ -15,7 +15,6 @@ package placement import ( - "encoding/json" "errors" . 
"github.com/pingcap/check" @@ -34,25 +33,20 @@ func (t *testRuleSuite) TestClone(c *C) { c.Assert(newRule, DeepEquals, &Rule{ID: "121"}) } -func matchRule(r1 *Rule, t2 []*Rule) bool { - for _, r2 := range t2 { - if ok, _ := DeepEquals.Check([]interface{}{r1, r2}, nil); ok { - return true - } - } - return false -} - func matchRules(t1, t2 []*Rule, prefix string, c *C) { - expected, err := json.Marshal(t1) - c.Assert(err, IsNil) - got, err := json.Marshal(t2) - c.Assert(err, IsNil) - comment := Commentf("%s\nexpected %s\nbut got %s", prefix, expected, got) - c.Assert(len(t1), Equals, len(t2), comment) - for _, r1 := range t1 { - comment = Commentf("%s\nerror on %+v\nexpected %s\nbut got %s", prefix, r1, expected, got) - c.Assert(matchRule(r1, t2), IsTrue, comment) + c.Assert(len(t2), Equals, len(t1), Commentf(prefix)) + for i := range t1 { + found := false + for j := range t2 { + ok, _ := DeepEquals.Check([]interface{}{t2[j], t1[i]}, []string{}) + if ok { + found = true + break + } + } + if !found { + c.Errorf("%s\n\ncan not found %d rule\n%+v\n%+v", prefix, i, t1[i], t2) + } } } @@ -71,11 +65,7 @@ func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { input: "", replicas: 3, output: []*Rule{ - { - Count: 3, - Role: Voter, - Constraints: Constraints{}, - }, + NewRule(Voter, 3, NewConstraintsDirect()), }, }) @@ -86,40 +76,30 @@ func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { err: ErrInvalidConstraintsRelicas, }) - labels, err := NewConstraints([]string{"+zone=sh", "+zone=sh"}) - c.Assert(err, IsNil) tests = append(tests, TestCase{ name: "normal array constraints", - input: `["+zone=sh", "+zone=sh"]`, + input: `["+zone=sh", "+region=sh"]`, replicas: 3, output: []*Rule{ - { - Count: 3, - Role: Voter, - Constraints: labels, - }, + NewRule(Voter, 3, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + NewConstraintDirect("region", In, "sh"), + )), }, }) - labels1, err := NewConstraints([]string{"+zone=sh", "-zone=bj"}) - c.Assert(err, IsNil) - labels2, 
err := NewConstraints([]string{"+zone=sh"}) - c.Assert(err, IsNil) tests = append(tests, TestCase{ name: "normal object constraints", input: `{"+zone=sh,-zone=bj":2, "+zone=sh": 1}`, replicas: 3, output: []*Rule{ - { - Count: 2, - Role: Voter, - Constraints: labels1, - }, - { - Count: 1, - Role: Voter, - Constraints: labels2, - }, + NewRule(Voter, 2, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + NewConstraintDirect("zone", NotIn, "bj"), + )), + NewRule(Voter, 1, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + )), }, }) @@ -128,20 +108,14 @@ func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { input: "{'+zone=sh,-zone=bj':2, '+zone=sh': 1}", replicas: 4, output: []*Rule{ - { - Count: 2, - Role: Voter, - Constraints: labels1, - }, - { - Count: 1, - Role: Voter, - Constraints: labels2, - }, - { - Count: 1, - Role: Voter, - }, + NewRule(Voter, 2, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + NewConstraintDirect("zone", NotIn, "bj"), + )), + NewRule(Voter, 1, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + )), + NewRule(Voter, 1, NewConstraintsDirect()), }, }) @@ -149,16 +123,13 @@ func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { name: "normal object constraints, without count", input: "{'+zone=sh,-zone=bj':2, '+zone=sh': 1}", output: []*Rule{ - { - Count: 2, - Role: Voter, - Constraints: labels1, - }, - { - Count: 1, - Role: Voter, - Constraints: labels2, - }, + NewRule(Voter, 2, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + NewConstraintDirect("zone", NotIn, "bj"), + )), + NewRule(Voter, 1, NewConstraintsDirect( + NewConstraintDirect("zone", In, "sh"), + )), }, }) diff --git a/ddl/placement_policy.go b/ddl/placement_policy.go index 71f5cdfbc2d5a..f36c71f98f4ba 100644 --- a/ddl/placement_policy.go +++ b/ddl/placement_policy.go @@ -67,30 +67,8 @@ func onCreatePlacementPolicy(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64 } func checkPolicyValidation(info 
*model.PlacementSettings) error { - checkMergeConstraint := func(replica uint64, constr1, constr2 string) error { - // Constr2 only make sense when replica is set (whether it is in the replica field or included in the constr1) - if replica == 0 && constr1 == "" { - return nil - } - if _, err := placement.NewMergeRules(replica, constr1, constr2); err != nil { - return err - } - return nil - } - if err := checkMergeConstraint(1, info.LeaderConstraints, info.Constraints); err != nil { - return err - } - if err := checkMergeConstraint(info.Followers, info.FollowerConstraints, info.Constraints); err != nil { - return err - } - if err := checkMergeConstraint(info.Voters, info.VoterConstraints, info.Constraints); err != nil { - return err - } - if err := checkMergeConstraint(info.Learners, info.LearnerConstraints, info.Constraints); err != nil { - return err - } - // For constraint labels and default region label, they should be checked by `SHOW LABELS` if necessary when it is applied. - return nil + _, err := placement.NewBundleFromOptions(info) + return err } func getPolicyInfo(t *meta.Meta, policyID int64) (*model.PolicyInfo, error) { @@ -230,45 +208,33 @@ func onAlterPlacementPolicy(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, return ver, errors.Trace(err) } - dbIDs, tblIDs, partIDs, err := getPlacementPolicyDependedObjectsIDs(t, oldPolicy) + dbIDs, partIDs, tblInfos, err := getPlacementPolicyDependedObjectsIDs(t, oldPolicy) if err != nil { return ver, errors.Trace(err) } - if len(dbIDs)+len(tblIDs)+len(partIDs) != 0 { + if len(dbIDs)+len(tblInfos)+len(partIDs) != 0 { // build bundle from new placement policy. bundle, err := placement.NewBundleFromOptions(newPolicyInfo.PlacementSettings) if err != nil { return ver, errors.Trace(err) } - err = bundle.Tidy() - if err != nil { - return ver, errors.Trace(err) - } // Do the http request only when the rules is existed. - bundles := make([]*placement.Bundle, 0, len(tblIDs)+len(partIDs)) - // Reset bundle for tables. 
- for _, id := range tblIDs { + bundles := make([]*placement.Bundle, 0, len(tblInfos)+len(partIDs)) + // Reset bundle for tables (including the default rule for partition). + for _, tbl := range tblInfos { cp := bundle.Clone() - bundles = append(bundles, cp.Reset(id)) - if len(bundle.Rules) == 0 { - bundle.Index = 0 - bundle.Override = false - } else { - bundle.Index = placement.RuleIndexTable - bundle.Override = true + ids := []int64{tbl.ID} + if tbl.Partition != nil { + for _, pDef := range tbl.Partition.Definitions { + ids = append(ids, pDef.ID) + } } + bundles = append(bundles, cp.Reset(placement.RuleIndexTable, ids)) } // Reset bundle for partitions. for _, id := range partIDs { cp := bundle.Clone() - bundles = append(bundles, cp.Reset(id)) - if len(bundle.Rules) == 0 { - bundle.Index = 0 - bundle.Override = false - } else { - bundle.Index = placement.RuleIndexPartition - bundle.Override = true - } + bundles = append(bundles, cp.Reset(placement.RuleIndexPartition, []int64{id})) } err = infosync.PutRuleBundles(context.TODO(), bundles) if err != nil { @@ -316,15 +282,15 @@ func checkPlacementPolicyNotInUseFromInfoSchema(is infoschema.InfoSchema, policy return nil } -func getPlacementPolicyDependedObjectsIDs(t *meta.Meta, policy *model.PolicyInfo) (dbIDs, tblIDs, partIDs []int64, err error) { +func getPlacementPolicyDependedObjectsIDs(t *meta.Meta, policy *model.PolicyInfo) (dbIDs, partIDs []int64, tblInfos []*model.TableInfo, err error) { schemas, err := t.ListDatabases() if err != nil { return nil, nil, nil, err } // DB ids don't have to set the bundle themselves, but to check the dependency. 
dbIDs = make([]int64, 0, len(schemas)) - tblIDs = make([]int64, 0, len(schemas)) partIDs = make([]int64, 0, len(schemas)) + tblInfos = make([]*model.TableInfo, 0, len(schemas)) for _, dbInfo := range schemas { if dbInfo.PlacementPolicyRef != nil && dbInfo.PlacementPolicyRef.ID == policy.ID { dbIDs = append(dbIDs, dbInfo.ID) @@ -335,7 +301,7 @@ func getPlacementPolicyDependedObjectsIDs(t *meta.Meta, policy *model.PolicyInfo } for _, tblInfo := range tables { if ref := tblInfo.PlacementPolicyRef; ref != nil && ref.ID == policy.ID { - tblIDs = append(tblIDs, tblInfo.ID) + tblInfos = append(tblInfos, tblInfo) } if tblInfo.Partition != nil { for _, part := range tblInfo.Partition.Definitions { @@ -346,7 +312,7 @@ func getPlacementPolicyDependedObjectsIDs(t *meta.Meta, policy *model.PolicyInfo } } } - return dbIDs, tblIDs, partIDs, nil + return dbIDs, partIDs, tblInfos, nil } func checkPlacementPolicyNotInUseFromMeta(t *meta.Meta, policy *model.PolicyInfo) error { diff --git a/ddl/placement_policy_test.go b/ddl/placement_policy_test.go index beaa795532eee..60d26939a8059 100644 --- a/ddl/placement_policy_test.go +++ b/ddl/placement_policy_test.go @@ -52,8 +52,6 @@ func (s *testDBSuite6) TestPlacementPolicy(c *C) { s.dom.DDL().(ddl.DDLForTest).SetHook(hook) tk.MustExec("create placement policy x " + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1,cn-east-2\" " + "LEARNERS=1 " + "LEARNER_CONSTRAINTS=\"[+region=cn-west-1]\" " + "VOTERS=3 " + @@ -62,8 +60,6 @@ func (s *testDBSuite6) TestPlacementPolicy(c *C) { checkFunc := func(policyInfo *model.PolicyInfo) { c.Assert(policyInfo.ID != 0, Equals, true) c.Assert(policyInfo.Name.L, Equals, "x") - c.Assert(policyInfo.PrimaryRegion, Equals, "cn-east-1") - c.Assert(policyInfo.Regions, Equals, "cn-east-1,cn-east-2") c.Assert(policyInfo.Followers, Equals, uint64(0)) c.Assert(policyInfo.FollowerConstraints, Equals, "") c.Assert(policyInfo.Voters, Equals, uint64(3)) @@ -135,64 +131,36 @@ func testGetPolicyByNameFromIS(c *C, 
ctx sessionctx.Context, policy string) *mod return po } -func (s *testDBSuite6) TestConstraintCompatibility(c *C) { +func (s *testDBSuite6) TestPlacementValidation(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop placement policy if exists x") cases := []struct { + name string settings string success bool errmsg string }{ - // Dict is not allowed for common constraint. { - settings: "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1,cn-east-2\" " + - "LEARNERS=1 " + + name: "Dict is not allowed for common constraint", + settings: "LEARNERS=1 " + "LEARNER_CONSTRAINTS=\"[+zone=cn-west-1]\" " + "CONSTRAINTS=\"{'+disk=ssd':2}\"", - errmsg: "invalid label constraints format: should be [constraint1, ...] (error yaml: unmarshal errors:\n line 1: cannot unmarshal !!map into []string)", + errmsg: "invalid label constraints format: 'Constraints' should be [constraint1, ...] or any yaml compatible array representation", }, - // Special constraints may be incompatible with itself. { - settings: "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1,cn-east-2\" " + - "LEARNERS=1 " + + name: "constraints may be incompatible with itself", + settings: "LEARNERS=1 " + "LEARNER_CONSTRAINTS=\"[+zone=cn-west-1, +zone=cn-west-2]\"", - errmsg: "conflicting label constraints: '+zone=cn-west-2' and '+zone=cn-west-1'", + errmsg: "invalid label constraints format: should be [constraint1, ...] 
(error conflicting label constraints: '+zone=cn-west-2' and '+zone=cn-west-1'), {constraint1: cnt1, ...} (error yaml: unmarshal errors:\n" + + " line 1: cannot unmarshal !!seq into map[string]int), or any yaml compatible representation: invalid LearnerConstraints", }, { settings: "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1,cn-east-2\" " + - "LEARNERS=1 " + - "LEARNER_CONSTRAINTS=\"[+zone=cn-west-1, -zone=cn-west-1]\"", - errmsg: "conflicting label constraints: '-zone=cn-west-1' and '+zone=cn-west-1'", - }, - { - settings: "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1,cn-east-2\" " + - "LEARNERS=1 " + - "LEARNER_CONSTRAINTS=\"[+zone=cn-west-1, +zone=cn-west-1]\"", + "REGIONS=\"cn-east-1,cn-east-2\" ", success: true, }, - // Special constraints may be incompatible with common constraint. - { - settings: "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + - "FOLLOWERS=2 " + - "FOLLOWER_CONSTRAINTS=\"[+zone=cn-east-1]\" " + - "CONSTRAINTS=\"[+zone=cn-east-2]\"", - errmsg: "conflicting label constraints: '+zone=cn-east-2' and '+zone=cn-east-1'", - }, - { - settings: "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + - "FOLLOWERS=2 " + - "FOLLOWER_CONSTRAINTS=\"[+zone=cn-east-1]\" " + - "CONSTRAINTS=\"[+disk=ssd,-zone=cn-east-1]\"", - errmsg: "conflicting label constraints: '-zone=cn-east-1' and '+zone=cn-east-1'", - }, } // test for create @@ -203,23 +171,23 @@ func (s *testDBSuite6) TestConstraintCompatibility(c *C) { tk.MustExec("drop placement policy if exists x") } else { err := tk.ExecToErr(sql) - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, ca.errmsg) + c.Assert(err, NotNil, Commentf(ca.name)) + c.Assert(err.Error(), Equals, ca.errmsg, Commentf(ca.name)) } } // test for alter - tk.MustExec("create placement policy x regions=\"cn-east1,cn-east\"") + tk.MustExec("create placement policy x primary_region=\"cn-east-1\" regions=\"cn-east-1,cn-east\"") for _, ca := range cases { sql := 
fmt.Sprintf("%s %s", "alter placement policy x", ca.settings) if ca.success { tk.MustExec(sql) - tk.MustExec("alter placement policy x regions=\"cn-east1,cn-east\"") + tk.MustExec("alter placement policy x primary_region=\"cn-east-1\" regions=\"cn-east-1,cn-east\"") } else { err := tk.ExecToErr(sql) c.Assert(err, NotNil) c.Assert(err.Error(), Equals, ca.errmsg) - tk.MustQuery("show placement where target='POLICY x'").Check(testkit.Rows("POLICY x REGIONS=\"cn-east1,cn-east\"")) + tk.MustQuery("show placement where target='POLICY x'").Check(testkit.Rows("POLICY x PRIMARY_REGION=\"cn-east-1\" REGIONS=\"cn-east-1,cn-east\"")) } } tk.MustExec("drop placement policy x") @@ -229,18 +197,18 @@ func (s *testDBSuite6) TestAlterPlacementPolicy(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop placement policy if exists x") - tk.MustExec("create placement policy x primary_region=\"cn-east-1\" regions=\"cn-east1,cn-east\"") + tk.MustExec("create placement policy x primary_region=\"cn-east-1\" regions=\"cn-east-1,cn-east\"") defer tk.MustExec("drop placement policy if exists x") // test for normal cases - tk.MustExec("alter placement policy x REGIONS=\"bj,sh\"") - tk.MustQuery("show placement where target='POLICY x'").Check(testkit.Rows("POLICY x REGIONS=\"bj,sh\"")) + tk.MustExec("alter placement policy x PRIMARY_REGION=\"bj\" REGIONS=\"bj,sh\"") + tk.MustQuery("show placement where target='POLICY x'").Check(testkit.Rows("POLICY x PRIMARY_REGION=\"bj\" REGIONS=\"bj,sh\"")) tk.MustExec("alter placement policy x " + "PRIMARY_REGION=\"bj\" " + - "REGIONS=\"sh\" " + + "REGIONS=\"bj\" " + "SCHEDULE=\"EVEN\"") - tk.MustQuery("show placement where target='POLICY x'").Check(testkit.Rows("POLICY x PRIMARY_REGION=\"bj\" REGIONS=\"sh\" SCHEDULE=\"EVEN\"")) + tk.MustQuery("show placement where target='POLICY x'").Check(testkit.Rows("POLICY x PRIMARY_REGION=\"bj\" REGIONS=\"bj\" SCHEDULE=\"EVEN\"")) tk.MustExec("alter placement policy x " + 
"LEADER_CONSTRAINTS=\"[+region=us-east-1]\" " + @@ -273,19 +241,16 @@ func (s *testDBSuite6) TestCreateTableWithPlacementPolicy(c *C) { // Direct placement option: special constraints may be incompatible with common constraint. _, err := tk.Exec("create table t(a int) " + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + "FOLLOWERS=2 " + "FOLLOWER_CONSTRAINTS=\"[+zone=cn-east-1]\" " + "CONSTRAINTS=\"[+disk=ssd,-zone=cn-east-1]\"") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "conflicting label constraints: '-zone=cn-east-1' and '+zone=cn-east-1'") + c.Assert(err, ErrorMatches, ".*conflicting label constraints.*") tk.MustExec("create table t(a int) " + "PRIMARY_REGION=\"cn-east-1\" " + "REGIONS=\"cn-east-1, cn-east-2\" " + - "FOLLOWERS=2 " + - "CONSTRAINTS=\"[+disk=ssd]\"") + "FOLLOWERS=2 ") tbl := testGetTableByName(c, tk.Se, "test", "t") c.Assert(tbl, NotNil) @@ -301,7 +266,7 @@ func (s *testDBSuite6) TestCreateTableWithPlacementPolicy(c *C) { c.Assert(policySetting.VoterConstraints, Equals, "") c.Assert(policySetting.Learners, Equals, uint64(0)) c.Assert(policySetting.LearnerConstraints, Equals, "") - c.Assert(policySetting.Constraints, Equals, "[+disk=ssd]") + c.Assert(policySetting.Constraints, Equals, "") c.Assert(policySetting.Schedule, Equals, "") } checkFunc(tbl.Meta().DirectPlacementOpts) @@ -321,8 +286,6 @@ func (s *testDBSuite6) TestCreateTableWithPlacementPolicy(c *C) { tk.MustGetErrCode("create table t(a int)"+ "PLACEMENT POLICY=\"x\"", mysql.ErrPlacementPolicyNotExists) tk.MustExec("create placement policy x " + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + "FOLLOWERS=2 " + "CONSTRAINTS=\"[+disk=ssd]\" ") tk.MustExec("create table t(a int)" + @@ -335,19 +298,7 @@ func (s *testDBSuite6) TestCreateTableWithPlacementPolicy(c *C) { c.Assert(tbl.Meta().PlacementPolicyRef.ID != 0, Equals, true) tk.MustExec("drop table if exists t") - // Only direct placement options should check the 
compatibility itself. - _, err = tk.Exec("create table t(a int)" + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + - "FOLLOWERS=2 " + - "FOLLOWER_CONSTRAINTS=\"[+zone=cn-east-1]\" " + - "CONSTRAINTS=\"[+disk=ssd, -zone=cn-east-1]\" ") - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "conflicting label constraints: '-zone=cn-east-1' and '+zone=cn-east-1'") - tk.MustExec("create table t(a int)" + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + "FOLLOWERS=2 " + "CONSTRAINTS=\"[+disk=ssd]\" ") @@ -356,8 +307,8 @@ func (s *testDBSuite6) TestCreateTableWithPlacementPolicy(c *C) { c.Assert(tbl.Meta().DirectPlacementOpts, NotNil) checkFunc = func(policySetting *model.PlacementSettings) { - c.Assert(policySetting.PrimaryRegion, Equals, "cn-east-1") - c.Assert(policySetting.Regions, Equals, "cn-east-1, cn-east-2") + c.Assert(policySetting.PrimaryRegion, Equals, "") + c.Assert(policySetting.Regions, Equals, "") c.Assert(policySetting.Followers, Equals, uint64(2)) c.Assert(policySetting.FollowerConstraints, Equals, "") c.Assert(policySetting.Voters, Equals, uint64(0)) @@ -473,7 +424,7 @@ func testGetPolicyDependency(storage kv.Storage, name string) []int64 { return ids } -func (s *testDBSuite6) TestPolicyCacheAndPolicyDependencyCache(c *C) { +func (s *testDBSuite6) TestPolicyCacheAndPolicyDependency(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop placement policy if exists x") @@ -544,3 +495,165 @@ func (s *testDBSuite6) TestPolicyCacheAndPolicyDependencyCache(c *C) { c.Assert(dependencies, NotNil) c.Assert(len(dependencies), Equals, 0) } + +func (s *testDBSuite6) TestAlterTablePartitionWithPlacementPolicy(c *C) { + tk := testkit.NewTestKit(c, s.store) + defer func() { + tk.MustExec("drop table if exists t1") + tk.MustExec("drop placement policy if exists x") + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + // Direct placement option: special 
constraints may be incompatible with common constraint. + tk.MustExec("create table t1 (c int) PARTITION BY RANGE (c) " + + "(PARTITION p0 VALUES LESS THAN (6)," + + "PARTITION p1 VALUES LESS THAN (11)," + + "PARTITION p2 VALUES LESS THAN (16)," + + "PARTITION p3 VALUES LESS THAN (21));") + + tk.MustExec("alter table t1 partition p0 " + + "PRIMARY_REGION=\"cn-east-1\" " + + "REGIONS=\"cn-east-1, cn-east-2\" " + + "FOLLOWERS=2 ") + + tbl := testGetTableByName(c, tk.Se, "test", "t1") + c.Assert(tbl, NotNil) + ptDef := testGetPartitionDefinitionsByName(c, tk.Se, "test", "t1", "p0") + c.Assert(ptDef.PlacementPolicyRef.Name.L, Equals, "") + c.Assert(ptDef.DirectPlacementOpts, NotNil) + + checkFunc := func(policySetting *model.PlacementSettings) { + c.Assert(policySetting.PrimaryRegion, Equals, "cn-east-1") + c.Assert(policySetting.Regions, Equals, "cn-east-1, cn-east-2") + c.Assert(policySetting.Followers, Equals, uint64(2)) + c.Assert(policySetting.FollowerConstraints, Equals, "") + c.Assert(policySetting.Voters, Equals, uint64(0)) + c.Assert(policySetting.VoterConstraints, Equals, "") + c.Assert(policySetting.Learners, Equals, uint64(0)) + c.Assert(policySetting.LearnerConstraints, Equals, "") + c.Assert(policySetting.Constraints, Equals, "") + c.Assert(policySetting.Schedule, Equals, "") + } + checkFunc(ptDef.DirectPlacementOpts) + + //Direct placement option and placement policy can't co-exist. + _, err := tk.Exec("alter table t1 partition p0 " + + "PRIMARY_REGION=\"cn-east-1\" " + + "REGIONS=\"cn-east-1, cn-east-2\" " + + "FOLLOWERS=2 " + + "PLACEMENT POLICY=\"x\"") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:8240]Placement policy 'x' can't co-exist with direct placement options") + + // Only placement policy should check the policy existence. 
+ tk.MustGetErrCode("alter table t1 partition p0 "+ + "PLACEMENT POLICY=\"x\"", mysql.ErrPlacementPolicyNotExists) + tk.MustExec("create placement policy x " + + "FOLLOWERS=2 ") + tk.MustExec("alter table t1 partition p0 " + + "PLACEMENT POLICY=\"x\"") + + ptDef = testGetPartitionDefinitionsByName(c, tk.Se, "test", "t1", "p0") + c.Assert(ptDef, NotNil) + c.Assert(ptDef.PlacementPolicyRef, NotNil) + c.Assert(ptDef.PlacementPolicyRef.Name.L, Equals, "x") + c.Assert(ptDef.PlacementPolicyRef.ID != 0, Equals, true) + + tk.MustExec("alter table t1 partition p0 " + + "PRIMARY_REGION=\"cn-east-1\" " + + "REGIONS=\"cn-east-1, cn-east-2\" " + + "FOLLOWERS=2 ") + + ptDef = testGetPartitionDefinitionsByName(c, tk.Se, "test", "t1", "p0") + c.Assert(ptDef, NotNil) + c.Assert(ptDef.DirectPlacementOpts, NotNil) + + checkFunc = func(policySetting *model.PlacementSettings) { + c.Assert(policySetting.PrimaryRegion, Equals, "cn-east-1") + c.Assert(policySetting.Regions, Equals, "cn-east-1, cn-east-2") + c.Assert(policySetting.Followers, Equals, uint64(2)) + c.Assert(policySetting.FollowerConstraints, Equals, "") + c.Assert(policySetting.Voters, Equals, uint64(0)) + c.Assert(policySetting.VoterConstraints, Equals, "") + c.Assert(policySetting.Learners, Equals, uint64(0)) + c.Assert(policySetting.LearnerConstraints, Equals, "") + c.Assert(policySetting.Constraints, Equals, "") + c.Assert(policySetting.Schedule, Equals, "") + } + checkFunc(ptDef.DirectPlacementOpts) +} + +func testGetPartitionDefinitionsByName(c *C, ctx sessionctx.Context, db string, table string, ptName string) model.PartitionDefinition { + dom := domain.GetDomain(ctx) + // Make sure the table schema is the new schema. 
+ err := dom.Reload() + c.Assert(err, IsNil) + tbl, err := dom.InfoSchema().TableByName(model.NewCIStr(db), model.NewCIStr(table)) + c.Assert(err, IsNil) + c.Assert(tbl, NotNil) + var ptDef model.PartitionDefinition + for _, def := range tbl.Meta().Partition.Definitions { + if ptName == def.Name.L { + ptDef = def + break + } + } + return ptDef +} + +func (s *testDBSuite6) TestPolicyInheritance(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop placement policy if exists x") + + // test table inherit database's placement rules. + tk.MustExec("create database mydb constraints=\"[+zone=hangzhou]\"") + tk.MustQuery("show create database mydb").Check(testkit.Rows("mydb CREATE DATABASE `mydb` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ /*T![placement] CONSTRAINTS=\"[+zone=hangzhou]\" */")) + + tk.MustExec("use mydb") + tk.MustExec("create table t(a int)") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] CONSTRAINTS=\"[+zone=hangzhou]\" */")) + tk.MustExec("drop table if exists t") + + tk.MustExec("create table t(a int) constraints=\"[+zone=suzhou]\"") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] CONSTRAINTS=\"[+zone=suzhou]\" */")) + tk.MustExec("drop table if exists t") + + // table will inherit db's placement rules, which is shared by all partition as default one. 
+ tk.MustExec("create table t(a int) partition by range(a) (partition p0 values less than (100), partition p1 values less than (200))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] CONSTRAINTS=\"[+zone=hangzhou]\" */\n" + + "PARTITION BY RANGE ( `a` ) (\n" + + " PARTITION `p0` VALUES LESS THAN (100),\n" + + " PARTITION `p1` VALUES LESS THAN (200)\n" + + ")")) + tk.MustExec("drop table if exists t") + + // partition's specified placement rules will override the default one. + tk.MustExec("create table t(a int) partition by range(a) (partition p0 values less than (100) constraints=\"[+zone=suzhou]\", partition p1 values less than (200))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] CONSTRAINTS=\"[+zone=hangzhou]\" */\n" + + "PARTITION BY RANGE ( `a` ) (\n" + + " PARTITION `p0` VALUES LESS THAN (100) /*T![placement] CONSTRAINTS=\"[+zone=suzhou]\" */,\n" + + " PARTITION `p1` VALUES LESS THAN (200)\n" + + ")")) + tk.MustExec("drop table if exists t") + + // test partition override table's placement rules. 
+ tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int) CONSTRAINTS=\"[+zone=suzhou]\" partition by range(a) (partition p0 values less than (100) CONSTRAINTS=\"[+zone=changzhou]\", partition p1 values less than (200))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] CONSTRAINTS=\"[+zone=suzhou]\" */\n" + + "PARTITION BY RANGE ( `a` ) (\n" + + " PARTITION `p0` VALUES LESS THAN (100) /*T![placement] CONSTRAINTS=\"[+zone=changzhou]\" */,\n" + + " PARTITION `p1` VALUES LESS THAN (200)\n" + + ")")) +} diff --git a/ddl/placement_sql_test.go b/ddl/placement_sql_test.go index 3b45e67ffc2e2..2fbba612f383d 100644 --- a/ddl/placement_sql_test.go +++ b/ddl/placement_sql_test.go @@ -731,15 +731,15 @@ func (s *testDBSuite6) TestCreateSchemaWithPlacement(c *C) { tk.Se.GetSessionVars().EnableAlterPlacement = false }() - tk.MustExec(`CREATE SCHEMA SchemaDirectPlacementTest PRIMARY_REGION='nl' REGIONS = "se,nz" FOLLOWERS=3`) - tk.MustQuery("SHOW CREATE SCHEMA schemadirectplacementtest").Check(testkit.Rows("SchemaDirectPlacementTest CREATE DATABASE `SchemaDirectPlacementTest` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ PRIMARY_REGION=\"nl\" REGIONS=\"se,nz\" FOLLOWERS=3")) + tk.MustExec(`CREATE SCHEMA SchemaDirectPlacementTest PRIMARY_REGION='se' REGIONS = "se,nz" FOLLOWERS=3`) + tk.MustQuery("SHOW CREATE SCHEMA schemadirectplacementtest").Check(testkit.Rows("SchemaDirectPlacementTest CREATE DATABASE `SchemaDirectPlacementTest` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ /*T![placement] PRIMARY_REGION=\"se\" REGIONS=\"se,nz\" FOLLOWERS=3 */")) tk.MustExec(`CREATE PLACEMENT POLICY PolicySchemaTest LEADER_CONSTRAINTS = "[+region=nl]" FOLLOWER_CONSTRAINTS="[+region=se]" FOLLOWERS=4 LEARNER_CONSTRAINTS="[+region=be]" LEARNERS=4`) tk.MustExec(`CREATE PLACEMENT POLICY PolicyTableTest LEADER_CONSTRAINTS = "[+region=tl]" 
FOLLOWER_CONSTRAINTS="[+region=tf]" FOLLOWERS=2 LEARNER_CONSTRAINTS="[+region=tle]" LEARNERS=1`) tk.MustQuery("SHOW PLACEMENT like 'POLICY %PolicySchemaTest%'").Check(testkit.Rows("POLICY PolicySchemaTest LEADER_CONSTRAINTS=\"[+region=nl]\" FOLLOWERS=4 FOLLOWER_CONSTRAINTS=\"[+region=se]\" LEARNERS=4 LEARNER_CONSTRAINTS=\"[+region=be]\"")) tk.MustQuery("SHOW PLACEMENT like 'POLICY %PolicyTableTest%'").Check(testkit.Rows("POLICY PolicyTableTest LEADER_CONSTRAINTS=\"[+region=tl]\" FOLLOWERS=2 FOLLOWER_CONSTRAINTS=\"[+region=tf]\" LEARNERS=1 LEARNER_CONSTRAINTS=\"[+region=tle]\"")) tk.MustExec("CREATE SCHEMA SchemaPolicyPlacementTest PLACEMENT POLICY = `PolicySchemaTest`") - tk.MustQuery("SHOW CREATE SCHEMA SCHEMAPOLICYPLACEMENTTEST").Check(testkit.Rows("SchemaPolicyPlacementTest CREATE DATABASE `SchemaPolicyPlacementTest` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ PLACEMENT POLICY = `PolicySchemaTest`")) + tk.MustQuery("SHOW CREATE SCHEMA SCHEMAPOLICYPLACEMENTTEST").Check(testkit.Rows("SchemaPolicyPlacementTest CREATE DATABASE `SchemaPolicyPlacementTest` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ /*T![placement] PLACEMENT POLICY=`PolicySchemaTest` */")) tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UseSchemaDefault (a int unsigned primary key, b varchar(255))`) tk.MustQuery(`SHOW CREATE TABLE SchemaDirectPlacementTest.UseSchemaDefault`).Check(testkit.Rows( @@ -747,21 +747,14 @@ func (s *testDBSuite6) TestCreateSchemaWithPlacement(c *C) { " `a` int(10) unsigned NOT NULL,\n" + " `b` varchar(255) DEFAULT NULL,\n" + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"nl\" REGIONS=\"se,nz\" FOLLOWERS=3 */")) - tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UseDirectPlacement (a int unsigned primary key, b varchar(255)) PRIMARY_REGION="se"`) + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"se\" REGIONS=\"se,nz\" 
FOLLOWERS=3 */")) + tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UseDirectPlacement (a int unsigned primary key, b varchar(255)) PRIMARY_REGION="se" REGIONS="se"`) tk.MustQuery(`SHOW CREATE TABLE SchemaDirectPlacementTest.UseDirectPlacement`).Check(testkit.Rows( "UseDirectPlacement CREATE TABLE `UseDirectPlacement` (\n" + " `a` int(10) unsigned NOT NULL,\n" + " `b` varchar(255) DEFAULT NULL,\n" + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"se\" */")) - tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UsePolicy (a int unsigned primary key, b varchar(255)) PLACEMENT POLICY = "PolicyTableTest"`) - tk.MustQuery(`SHOW CREATE TABLE SchemaDirectPlacementTest.UsePolicy`).Check(testkit.Rows( - "UsePolicy CREATE TABLE `UsePolicy` (\n" + - " `a` int(10) unsigned NOT NULL,\n" + - " `b` varchar(255) DEFAULT NULL,\n" + - " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`PolicyTableTest` */")) + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"se\" REGIONS=\"se\" */")) tk.MustExec(`CREATE TABLE SchemaPolicyPlacementTest.UseSchemaDefault (a int unsigned primary key, b varchar(255))`) tk.MustQuery(`SHOW CREATE TABLE SchemaPolicyPlacementTest.UseSchemaDefault`).Check(testkit.Rows( @@ -770,13 +763,7 @@ func (s *testDBSuite6) TestCreateSchemaWithPlacement(c *C) { " `b` varchar(255) DEFAULT NULL,\n" + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`PolicySchemaTest` */")) - tk.MustExec(`CREATE TABLE SchemaPolicyPlacementTest.UseDirectPlacement (a int unsigned primary key, b varchar(255)) PRIMARY_REGION="se"`) - tk.MustQuery(`SHOW CREATE TABLE SchemaPolicyPlacementTest.UseDirectPlacement`).Check(testkit.Rows( - 
"UseDirectPlacement CREATE TABLE `UseDirectPlacement` (\n" + - " `a` int(10) unsigned NOT NULL,\n" + - " `b` varchar(255) DEFAULT NULL,\n" + - " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"se\" */")) + tk.MustExec(`CREATE TABLE SchemaPolicyPlacementTest.UsePolicy (a int unsigned primary key, b varchar(255)) PLACEMENT POLICY = "PolicyTableTest"`) tk.MustQuery(`SHOW CREATE TABLE SchemaPolicyPlacementTest.UsePolicy`).Check(testkit.Rows( "UsePolicy CREATE TABLE `UsePolicy` (\n" + @@ -791,7 +778,7 @@ func (s *testDBSuite6) TestCreateSchemaWithPlacement(c *C) { c.Assert(ok, IsTrue) c.Assert(db.PlacementPolicyRef, IsNil) c.Assert(db.DirectPlacementOpts, NotNil) - c.Assert(db.DirectPlacementOpts.PrimaryRegion, Matches, "nl") + c.Assert(db.DirectPlacementOpts.PrimaryRegion, Matches, "se") c.Assert(db.DirectPlacementOpts.Regions, Matches, "se,nz") c.Assert(db.DirectPlacementOpts.Followers, Equals, uint64(3)) c.Assert(db.DirectPlacementOpts.Learners, Equals, uint64(0)) diff --git a/ddl/table.go b/ddl/table.go index 507ca5b404451..46db57eba2075 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -67,56 +67,28 @@ func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) } return ver, errors.Trace(err) } - // Can not use both a placement policy and direct assignment. If you alter specify both in a CREATE TABLE or ALTER TABLE an error will be returned. - if tbInfo.DirectPlacementOpts != nil && tbInfo.PlacementPolicyRef != nil { - return ver, errors.Trace(ErrPlacementPolicyWithDirectOption.GenWithStackByArgs(tbInfo.PlacementPolicyRef.Name)) - } - var bundle *placement.Bundle - if tbInfo.DirectPlacementOpts != nil { - // check the direct placement option compatibility. - if err := checkPolicyValidation(tbInfo.DirectPlacementOpts); err != nil { - return ver, errors.Trace(err) - } - // build bundle from direct placement options. 
- bundle, err = placement.NewBundleFromOptions(tbInfo.DirectPlacementOpts) - if err != nil { - return ver, errors.Trace(err) - } + // placement rules meta inheritance. + dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) + if err != nil { + return ver, errors.Trace(err) } - if tbInfo.PlacementPolicyRef != nil { - // placement policy reference will override the direct placement options. - po, err := checkPlacementPolicyExistAndCancelNonExistJob(t, job, tbInfo.PlacementPolicyRef.ID) - if err != nil { - return ver, errors.Trace(infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(tbInfo.PlacementPolicyRef.Name)) - } - // build bundle from placement policy. - bundle, err = placement.NewBundleFromOptions(po.PlacementSettings) - if err != nil { - return ver, errors.Trace(err) - } + err = inheritPlacementRuleFromDB(tbInfo, dbInfo) + if err != nil { + return ver, errors.Trace(err) } - if bundle == nil { - // get the default bundle from DB or PD. - bundle = infoschema.GetBundle(d.infoCache.GetLatest(), []int64{schemaID}) + + // build table & partition bundles if any. + tableBundle, err := newBundleFromTblInfo(t, job, tbInfo) + if err != nil { + return ver, errors.Trace(err) } - // Do the http request only when the rules is existed. - syncPlacementRules := func() error { - if bundle.Rules == nil { - return nil - } - err = bundle.Tidy() - if err != nil { - return errors.Trace(err) - } - // todo: partitions should use the default table level placement rules or it's specified one. 
- bundle.Reset(tbInfo.ID) - err = infosync.PutRuleBundles(context.TODO(), []*placement.Bundle{bundle}) - if err != nil { - job.State = model.JobStateCancelled - return errors.Wrapf(err, "failed to notify PD the placement rules") - } - return nil + partitionBundles, err := newBundleFromPartition(t, job, tbInfo.Partition) + if err != nil { + return ver, errors.Trace(err) } + bundles := make([]*placement.Bundle, 0, 1+len(partitionBundles)) + bundles = append(bundles, tableBundle) + bundles = append(bundles, partitionBundles...) ver, err = updateSchemaVersion(t, job) if err != nil { @@ -139,8 +111,10 @@ func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) } }) // Send the placement bundle to PD. - if err = syncPlacementRules(); err != nil { - return ver, errors.Trace(err) + err = infosync.PutRuleBundles(context.TODO(), bundles) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") } // Finish this job. @@ -152,6 +126,92 @@ func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) } } +func inheritPlacementRuleFromDB(tbInfo *model.TableInfo, dbInfo *model.DBInfo) error { + if tbInfo.DirectPlacementOpts == nil && tbInfo.PlacementPolicyRef == nil { + if dbInfo.DirectPlacementOpts != nil { + clone := *dbInfo.DirectPlacementOpts + tbInfo.DirectPlacementOpts = &clone + } + if dbInfo.PlacementPolicyRef != nil { + clone := *dbInfo.PlacementPolicyRef + tbInfo.PlacementPolicyRef = &clone + } + } + // Can not use both a placement policy and direct assignment. If you alter specify both in a CREATE TABLE or ALTER TABLE an error will be returned. 
+ if tbInfo.DirectPlacementOpts != nil && tbInfo.PlacementPolicyRef != nil { + return ErrPlacementPolicyWithDirectOption.GenWithStackByArgs(tbInfo.PlacementPolicyRef.Name) + } + return nil +} + +func newBundleFromTblInfo(t *meta.Meta, job *model.Job, tbInfo *model.TableInfo) (*placement.Bundle, error) { + bundle, err := newBundleFromPolicyOrDirectOptions(t, job, tbInfo.PlacementPolicyRef, tbInfo.DirectPlacementOpts) + if err != nil { + return nil, errors.Trace(err) + } + if bundle == nil { + return nil, nil + } + ids := []int64{tbInfo.ID} + // build the default partition rules in the table-level bundle. + if tbInfo.Partition != nil { + for _, pDef := range tbInfo.Partition.Definitions { + ids = append(ids, pDef.ID) + } + } + bundle.Reset(placement.RuleIndexTable, ids) + return bundle, nil +} + +func newBundleFromPartition(t *meta.Meta, job *model.Job, partition *model.PartitionInfo) ([]*placement.Bundle, error) { + if partition == nil { + return nil, nil + } + bundles := make([]*placement.Bundle, 0, len(partition.Definitions)) + // If the partition has the placement rules on their own, build the partition-level bundles additionally. + for _, def := range partition.Definitions { + bundle, err := newBundleFromPolicyOrDirectOptions(t, job, def.PlacementPolicyRef, def.DirectPlacementOpts) + if err != nil { + return nil, errors.Trace(err) + } + if bundle == nil { + continue + } + bundle.Reset(placement.RuleIndexPartition, []int64{def.ID}) + bundles = append(bundles, bundle) + continue + } + return bundles, nil +} + +func newBundleFromPolicyOrDirectOptions(t *meta.Meta, job *model.Job, ref *model.PolicyRefInfo, directOpts *model.PlacementSettings) (*placement.Bundle, error) { + if directOpts != nil { + // build bundle from direct placement options. 
+ bundle, err := placement.NewBundleFromOptions(directOpts) + if err != nil { + job.State = model.JobStateCancelled + return nil, errors.Trace(err) + } + return bundle, nil + } + if ref != nil { + // placement policy reference will override the direct placement options. + po, err := checkPlacementPolicyExistAndCancelNonExistJob(t, job, ref.ID) + if err != nil { + job.State = model.JobStateCancelled + return nil, errors.Trace(infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(ref.Name)) + } + // build bundle from placement policy. + bundle, err := placement.NewBundleFromOptions(po.PlacementSettings) + if err != nil { + job.State = model.JobStateCancelled + return nil, errors.Trace(err) + } + return bundle, nil + } + return nil, nil +} + func createTableOrViewWithCheck(t *meta.Meta, job *model.Job, schemaID int64, tbInfo *model.TableInfo) error { err := checkTableInfoValid(tbInfo) if err != nil { @@ -588,7 +648,7 @@ func onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ erro bundles := make([]*placement.Bundle, 0, len(oldPartitionIDs)+1) if oldBundle, ok := is.BundleByName(placement.GroupID(tableID)); ok { - bundles = append(bundles, oldBundle.Clone().Reset(newTableID)) + bundles = append(bundles, oldBundle.Clone().Reset(placement.RuleIndexTable, []int64{newTableID})) } if pi := tblInfo.GetPartitionInfo(); pi != nil { @@ -600,7 +660,7 @@ func onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ erro if oldBundle, ok := is.BundleByName(placement.GroupID(oldPartitionIDs[i])); ok && !oldBundle.IsEmpty() { oldIDs = append(oldIDs, oldPartitionIDs[i]) newIDs = append(newIDs, newID) - bundles = append(bundles, oldBundle.Clone().Reset(newID)) + bundles = append(bundles, oldBundle.Clone().Reset(placement.RuleIndexPartition, []int64{newID})) } } job.CtxVars = []interface{}{oldIDs, newIDs} @@ -696,7 +756,7 @@ func onRebaseAutoID(store kv.Storage, t *meta.Meta, job *model.Job, tp autoid.Al if force { err = alloc.ForceRebase(newEnd) } 
else { - err = alloc.Rebase(newEnd, false) + err = alloc.Rebase(context.Background(), newEnd, false) } if err != nil { job.State = model.JobStateCancelled @@ -1292,6 +1352,45 @@ func onAlterTablePartitionAttributes(t *meta.Meta, job *model.Job) (ver int64, e return ver, nil } +func onAlterTablePartitionOptions(t *meta.Meta, job *model.Job) (ver int64, err error) { + var partitionID int64 + policyRefInfo := &model.PolicyRefInfo{} + placementSettings := &model.PlacementSettings{} + err = job.DecodeArgs(&partitionID, policyRefInfo, placementSettings) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + tblInfo, err := getTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, err + } + + ptInfo := tblInfo.GetPartitionInfo() + isFound := false + definitions := ptInfo.Definitions + for i := range definitions { + if partitionID == definitions[i].ID { + definitions[i].DirectPlacementOpts = placementSettings + definitions[i].PlacementPolicyRef = policyRefInfo + isFound = true + break + } + } + if !isFound { + job.State = model.JobStateCancelled + return 0, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs("drop?", tblInfo.Name.O)) + } + + ver, err = updateVersionAndTableInfo(t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + + return ver, nil +} + func getOldLabelRules(tblInfo *model.TableInfo, oldSchemaName, oldTableName string) (string, []string, []string, map[string]*label.Rule, error) { tableRuleID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, oldSchemaName, oldTableName) oldRuleIDs := []string{tableRuleID} diff --git a/docs/design/2021-03-01-pipelined-window-functions.md b/docs/design/2021-03-01-pipelined-window-functions.md new file mode 100644 index 0000000000000..1a1e1024b332d --- /dev/null +++ b/docs/design/2021-03-01-pipelined-window-functions.md @@ -0,0 +1,113 @@ +# Proposal: Pipeline 
Window Function Execution + +- Author(s): [ichn-hu](https://github.com/ichn-hu) +- Discussion at: https://github.com/pingcap/tidb/pull/23028 +- Tracking issue: https://github.com/pingcap/tidb/pull/23022 + +## Note + +* Row number is often shortened to RN, and RNF for RN function +* Window function is often shortened to WF + +## Abstract + +This document proposes to support executing window functions in a pipelined manner. + +## Background + +The current WF implementation materialized a whole partition before processing it, and if a partition is too large, it will cause TiDB OOM. One particular example is seen in [issue/18444](https://github.com/pingcap/tidb/issues/18444) where the whole table is processed as a single partition in order to get a row number for the paging scenario, while the alternative solution using user variable could significantly decrease the memory usage. + +As the cause is clear, we aim to pipeline the calculation of some of the window function, which means the window function executor will return data as soon as possible before the whole partition is consumed. After this design is implemented, the evaluation of RN WF will not cause the whole partition to be materialized, instead, it can be processed in a pipelined manner in the whole executor pipeline, that’s why we call it pipelining. + +### Review of current implementation + +The current window function implementation is like this (with a focus on processing RN): + +1. Data is sorted by partition key and order by key when feeding to window function. +2. vecGroupChecker is used to split data by the partition key. +3. Data is accumulated in groupRows until a whole partition is read from child executor. +4. Then e.consumeGroupRows will be called, which in turn uses windowProcessor to process the rows. +5. There are current 3 processor types that implement the windowProcessor interface: + 1. aggWindowProcessor, dealt with partition without frame constraint, i.e. 
the function will be called upon the whole partition, e.g. sum over whole partition, then every row gets the same result on the window function, it is indeed confusing that RN is implemented on aggWindowProcessor, latter we’ll show that it is more natural to be expressed in rowFrameWindowProcessor. + 2. rowFrameWindowProcessor, dealt with partition with ROWS frame constraint, i.e. a fixed length bounding window sliding over rows, each step produced a new value given the rows within the window. Note the window can have unbounded preceding and following. + 3. rangeFrameWindowProcessor, with RANGES frame constraint, i.e. the window is defined by value range, so it can vary (a lot) from row to row. +6. For RN, it only uses `aggWindowProcessor`, as [the MySQL document](https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html) pointed out. + +> Standard SQL specifies that window functions that operate on the entire partition should have no frame clause. MySQL permits a frame clause for such functions but ignores it. These functions use the entire partition even if a frame is specified: + +* CUME_DIST() +* DENSE_RANK() +* LAG() +* LEAD() +* NTILE() +* PERCENT_RANK() +* RANK() +* ROW_NUMBER() + +7. In aggWindowProcessor, three functions are implemented: + 1. consumeGroupRows: call agg function’s UpdatePartialResult on all rows within a partition + 2. appendResult2Chunk: call agg function’s AppendFinalResult2Chunk and write result to the result chunk, this function is called repetitively until every row is processed in a partition + 3. resetPartialResult: call agg function’s ResetPartialResult +8. 
Accordingly, the RN agg function does nothing on UpdatePartialResult, increases the RN counter and append to result on AppendFinalResult2Chunk and resets the counter on ResetPartialResult + +## Proposal + +After carefully examining the source code, we provide the following solution, which is based on unifying windowProcessor, and then pipeline it, so that RN function as well as many other WF currently using sliding windows can be pipelined. + +### Unify windowProcessor + +* For rowFrameWindowProcessor and rangeFrameWindowProcessor does nothing in consumeGroupRows + * And they will call Slide if the WF has implemented slide (i.e. Slide, AppendFinalResult2Chunk), or it will recalculate the result on the whole frame using the traditional aggFunc calculation strategy (i.e. UpdatePartialResult and then AppendFinalResult2Chunk and ResetPartialResult for each row) +* The Slide implementation is by nature pipelinable. + * **The two sides of the sliding window only moves monotonically**. + * However, the current implementation requires the number of rows in the whole partition to be known (or it can’t be pipelined) if the end is unbounded. +* For aggWindowProcessor: + * RN can definitely be pipelined, and it can be implemented in a sliding way (the window is the current row itself) + * Aggregation over the whole partition can’t be pipelined, and it can only be processed after the whole partition is ready. + +However, we could see it as the sliding window is the whole partition for each row + +### How to unify? + +We need to modify the executor build to support this: + +* For row number: the sliding window is of length 1, it slides with current row, **i.e. is a rowFrame start at currentRow and end at currentRow** +* For other agg functions on the whole partition: the sliding window is the whole partition, invariant for each row, **i.e. 
is a rowFrame start at unbounded preceding and end at unbounded following** + +### Pipelining + +* assume the total number of rows in a partition is N, which we do not know in advance since the data is pipelined + * UpdatePartialResult: append partial rows to the aggregation function, needs to append N rows eventually + * Slide (perhaps this function is better implemented on windowProcessor): must be called before calling AppendFinalResult2Chunk, it returns success or fail if the current rows consumed by UpdatePartialResult is not enough to slide for next row, it will also return an estimated number of rows so that it can be called (useful for unbounded following, we can use -1 to denote that it needs the whole partition, 0 means success immediately, and n if we could know the number of rows needed (for rowFrame) or 1 for rangeFrame since we need to examine row by row + * AppendFinalResult2Chunk: append result for one row, can be called N times, and can only be called after a successful slide + * FinishUpdate: called upon the whole partition has been appended, this is to notify the SlidingWindowAggFunc that the whole partition is consumed, so that for those function that needs the whole partition, it is now time to return success on slide +* We want the movement of the sliding window to be the driver for data fetching on the children executor, so the dataflow logic needs to be modified + * Next() will call windowProcessor’s Slide function to drive it + * Slide function will fetch data from child, and use vecGroupCheck to split it + * Then the data is processed at maximum effort using + * UpdatePartialResult + * Slide or do nothing if it is not SlidingWindowAggFunc + * If Slide returns success or the whole partition is processed + * Obtain the result using AppendFinalResult2Chunk + * Result is then feed back to Next, and returned once the chunk is full + * There could be a background goroutine pulling data + +## Rationale + +This feature will decrease memory 
consumption for executing window function. + +## Compatibility + +Pipelining won't cause any compatibility issue. + +## Implementation + +All implemented by [PR23022](https://github.com/pingcap/tidb/pull/23022). + +* [x] Create PipelinedWindowExec based on current implementation and modify the windowProcessor interface. +* [x] Change data flow, make Next() pulling data from windowProcessor, and windowProcessor calls fetchChild and process data at maximum effort. +* [x] Modify Slide semantic and add FinishUpdate function on SlidingWindowAggFunc interface, and modify correspondingly on each window function. +* [x] Done pipelining for SlidingWindowAggFunc, add test to make sure it is correct. +* [x] Modify RN to be SlidingWindowAggFunc, and add planner support. +* [x] Add test for RN. +* [x] Benchmark, make sure it has constant memory consumption and no execution time regression. diff --git a/executor/adapter_test.go b/executor/adapter_test.go index ebc7fc60246d0..40ae6e168b1f0 100644 --- a/executor/adapter_test.go +++ b/executor/adapter_test.go @@ -15,24 +15,29 @@ package executor_test import ( + "testing" "time" - . 
"github.com/pingcap/check" - "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" ) -func (s *testSuiteP2) TestQueryTime(c *C) { - tk := testkit.NewTestKit(c, s.store) +func TestQueryTime(t *testing.T) { + t.Parallel() + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) tk.MustExec("use test") - costTime := time.Since(tk.Se.GetSessionVars().StartTime) - c.Assert(costTime < 1*time.Second, IsTrue) + costTime := time.Since(tk.Session().GetSessionVars().StartTime) + require.Less(t, costTime, time.Second) tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int)") tk.MustExec("insert into t values(1), (1), (1), (1), (1)") tk.MustExec("select * from t t1 join t t2 on t1.a = t2.a") - costTime = time.Since(tk.Se.GetSessionVars().StartTime) - c.Assert(costTime < 1*time.Second, IsTrue) + costTime = time.Since(tk.Session().GetSessionVars().StartTime) + require.Less(t, costTime, time.Second) } diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index a5cfa6d9af198..b060146244211 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -43,6 +43,7 @@ import ( "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/set" + "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/testutils" ) @@ -362,6 +363,110 @@ func buildMultiArgsAggMemTester(funcName string, tps []byte, rt byte, numRows in return pt } +func testMergePartialResult(t *testing.T, p aggTest) { + ctx := mock.NewContext() + srcChk := p.genSrcChk() + iter := chunk.NewIterator4Chunk(srcChk) + + args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} + if p.funcName == ast.AggFuncGroupConcat { + args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)}) + } + desc, err := 
aggregation.NewAggFuncDesc(ctx, p.funcName, args, false) + require.NoError(t, err) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } + partialDesc, finalDesc := desc.Split([]int{0, 1}) + + // build partial func for partial phase. + partialFunc := aggfuncs.Build(ctx, partialDesc, 0) + partialResult, _ := partialFunc.AllocPartialResult() + + // build final func for final phase. + finalFunc := aggfuncs.Build(ctx, finalDesc, 0) + finalPr, _ := finalFunc.AllocPartialResult() + resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, 1) + if p.funcName == ast.AggFuncApproxCountDistinct { + resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeString)}, 1) + } + if p.funcName == ast.AggFuncJsonArrayagg { + resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeJSON)}, 1) + } + + // update partial result. + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + _, err = partialFunc.UpdatePartialResult(ctx, []chunk.Row{row}, partialResult) + require.NoError(t, err) + } + p.messUpChunk(srcChk) + err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk) + require.NoError(t, err) + dt := resultChk.GetRow(0).GetDatum(0, p.dataType) + if p.funcName == ast.AggFuncApproxCountDistinct { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString)) + } + if p.funcName == ast.AggFuncJsonArrayagg { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) + } + result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) + + _, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr) + require.NoError(t, err) + partialFunc.ResetPartialResult(partialResult) + + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) + iter.Begin() + iter.Next() + for row := iter.Next(); row != 
iter.End(); row = iter.Next() { + _, err = partialFunc.UpdatePartialResult(ctx, []chunk.Row{row}, partialResult) + require.NoError(t, err) + } + p.messUpChunk(srcChk) + resultChk.Reset() + err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk) + require.NoError(t, err) + dt = resultChk.GetRow(0).GetDatum(0, p.dataType) + if p.funcName == ast.AggFuncApproxCountDistinct { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString)) + } + if p.funcName == ast.AggFuncJsonArrayagg { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) + } + result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) + _, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr) + require.NoError(t, err) + + if p.funcName == ast.AggFuncApproxCountDistinct { + resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) + } + if p.funcName == ast.AggFuncJsonArrayagg { + resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeJSON)}, 1) + } + resultChk.Reset() + err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) + require.NoError(t, err) + + dt = resultChk.GetRow(0).GetDatum(0, p.dataType) + if p.funcName == ast.AggFuncApproxCountDistinct { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeLonglong)) + } + if p.funcName == ast.AggFuncJsonArrayagg { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) + } + result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[2]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[2]) +} + +// Deprecated: migrating to testMergePartialResult(t *testing.T, p aggTest) func (s *testSuite) testMergePartialResult(c *C, p aggTest) { srcChk := p.genSrcChk() iter := chunk.NewIterator4Chunk(srcChk) @@ -615,6 +720,96 
@@ func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum { return nil } +func testAggFunc(t *testing.T, p aggTest) { + srcChk := p.genSrcChk() + ctx := mock.NewContext() + + args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} + if p.funcName == ast.AggFuncGroupConcat { + args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)}) + } + if p.funcName == ast.AggFuncApproxPercentile { + args = append(args, &expression.Constant{Value: types.NewIntDatum(50), RetType: types.NewFieldType(mysql.TypeLong)}) + } + desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false) + require.NoError(t, err) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } + finalFunc := aggfuncs.Build(ctx, desc, 0) + finalPr, _ := finalFunc.AllocPartialResult() + resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) + + iter := chunk.NewIterator4Chunk(srcChk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + _, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr) + require.NoError(t, err) + } + p.messUpChunk(srcChk) + err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) + require.NoError(t, err) + dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) + result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) + + // test the empty input + resultChk.Reset() + finalFunc.ResetPartialResult(finalPr) + err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) + require.NoError(t, err) + dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) + result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) + + // test the agg func with distinct + desc, err = 
aggregation.NewAggFuncDesc(ctx, p.funcName, args, true) + require.NoError(t, err) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } + finalFunc = aggfuncs.Build(ctx, desc, 0) + finalPr, _ = finalFunc.AllocPartialResult() + + resultChk.Reset() + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + _, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr) + require.NoError(t, err) + } + p.messUpChunk(srcChk) + srcChk = p.genSrcChk() + iter = chunk.NewIterator4Chunk(srcChk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + _, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr) + require.NoError(t, err) + } + p.messUpChunk(srcChk) + err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) + require.NoError(t, err) + dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) + result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[1]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) + + // test the empty input + resultChk.Reset() + finalFunc.ResetPartialResult(finalPr) + err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) + require.NoError(t, err) + dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) + result, err = dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[0]) + require.NoError(t, err) + require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) +} + +// Deprecated: migrating to func testAggFunc(t *testing.T, p aggTest) func (s *testSuite) testAggFunc(c *C, p aggTest) { srcChk := p.genSrcChk() @@ -748,6 +943,38 @@ func (s *testSuite) testAggFuncWithoutDistinct(c *C, p aggTest) { c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0])) } +func testAggMemFunc(t *testing.T, p aggMemTest) { + srcChk := p.aggTest.genSrcChk() + ctx := mock.NewContext() + + args := 
[]expression.Expression{&expression.Column{RetType: p.aggTest.dataType, Index: 0}} + if p.aggTest.funcName == ast.AggFuncGroupConcat { + args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)}) + } + desc, err := aggregation.NewAggFuncDesc(ctx, p.aggTest.funcName, args, p.isDistinct) + require.NoError(t, err) + if p.aggTest.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } + finalFunc := aggfuncs.Build(ctx, desc, 0) + finalPr, memDelta := finalFunc.AllocPartialResult() + require.Equal(t, p.allocMemDelta, memDelta) + + updateMemDeltas, err := p.updateMemDeltaGens(srcChk, p.aggTest.dataType) + require.NoError(t, err) + iter := chunk.NewIterator4Chunk(srcChk) + i := 0 + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + memDelta, err := finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr) + require.NoError(t, err) + require.Equal(t, updateMemDeltas[i], memDelta) + i++ + } +} + +// Deprecated: migrating to testAggMemFunc(t *testing.T, p aggMemTest) func (s *testSuite) testAggMemFunc(c *C, p aggMemTest) { srcChk := p.aggTest.genSrcChk() diff --git a/executor/aggfuncs/func_cume_dist_test.go b/executor/aggfuncs/func_cume_dist_test.go index d5a6c5dc2df85..7a0180b272fd1 100644 --- a/executor/aggfuncs/func_cume_dist_test.go +++ b/executor/aggfuncs/func_cume_dist_test.go @@ -15,13 +15,16 @@ package aggfuncs_test import ( - . 
"github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" ) -func (s *testSuite) TestMemCumeDist(c *C) { +func TestMemCumeDist(t *testing.T) { + t.Parallel() + tests := []windowMemTest{ buildWindowMemTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 1, 1, aggfuncs.DefPartialResult4CumeDistSize, rowMemDeltaGens), @@ -31,6 +34,6 @@ func (s *testSuite) TestMemCumeDist(c *C) { aggfuncs.DefPartialResult4CumeDistSize, rowMemDeltaGens), } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/func_lead_lag_test.go b/executor/aggfuncs/func_lead_lag_test.go index 9dc49c19208d6..7459b8082dca8 100644 --- a/executor/aggfuncs/func_lead_lag_test.go +++ b/executor/aggfuncs/func_lead_lag_test.go @@ -15,7 +15,8 @@ package aggfuncs_test import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" @@ -23,7 +24,9 @@ import ( "github.com/pingcap/tidb/types" ) -func (s *testSuite) TestLeadLag(c *C) { +func TestLeadLag(t *testing.T) { + t.Parallel() + zero := expression.NewZero() one := expression.NewOne() two := &expression.Constant{ @@ -111,12 +114,14 @@ func (s *testSuite) TestLeadLag(c *C) { []expression.Expression{million, defaultArg}, 0, numRows, 0, 1, 2), } for _, test := range tests { - s.testWindowFunc(c, test) + testWindowFunc(t, test) } } -func (s *testSuite) TestMemLeadLag(c *C) { +func TestMemLeadLag(t *testing.T) { + t.Parallel() + zero := expression.NewZero() one := expression.NewOne() two := &expression.Constant{ @@ -160,7 +165,7 @@ func (s *testSuite) TestMemLeadLag(c *C) { } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/func_ntile_test.go b/executor/aggfuncs/func_ntile_test.go index da820a9d6c3bc..fdfcd9d08291f 100644 
--- a/executor/aggfuncs/func_ntile_test.go +++ b/executor/aggfuncs/func_ntile_test.go @@ -15,14 +15,16 @@ package aggfuncs_test import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/executor/aggfuncs" ) -func (s *testSuite) TestMemNtile(c *C) { +func TestMemNtile(t *testing.T) { + t.Parallel() + tests := []windowMemTest{ buildWindowMemTester(ast.WindowFuncNtile, mysql.TypeLonglong, 1, 1, 1, aggfuncs.DefPartialResult4Ntile, defaultUpdateMemDeltaGens), @@ -32,6 +34,6 @@ func (s *testSuite) TestMemNtile(c *C) { aggfuncs.DefPartialResult4Ntile, defaultUpdateMemDeltaGens), } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/func_percent_rank_test.go b/executor/aggfuncs/func_percent_rank_test.go index ce174a2342be8..5e04218e19ddd 100644 --- a/executor/aggfuncs/func_percent_rank_test.go +++ b/executor/aggfuncs/func_percent_rank_test.go @@ -15,13 +15,16 @@ package aggfuncs_test import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" ) -func (s *testSuite) TestMemPercentRank(c *C) { +func TestMemPercentRank(t *testing.T) { + t.Parallel() + tests := []windowMemTest{ buildWindowMemTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 1, 1, aggfuncs.DefPartialResult4RankSize, rowMemDeltaGens), @@ -31,6 +34,6 @@ func (s *testSuite) TestMemPercentRank(c *C) { aggfuncs.DefPartialResult4RankSize, rowMemDeltaGens), } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/func_rank_test.go b/executor/aggfuncs/func_rank_test.go index eeb5f13724dcf..211ade16583bd 100644 --- a/executor/aggfuncs/func_rank_test.go +++ b/executor/aggfuncs/func_rank_test.go @@ -15,13 +15,16 @@ package aggfuncs_test import ( - . 
"github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" ) -func (s *testSuite) TestMemRank(c *C) { +func TestMemRank(t *testing.T) { + t.Parallel() + tests := []windowMemTest{ buildWindowMemTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 1, 1, aggfuncs.DefPartialResult4RankSize, rowMemDeltaGens), @@ -31,6 +34,6 @@ func (s *testSuite) TestMemRank(c *C) { aggfuncs.DefPartialResult4RankSize, rowMemDeltaGens), } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/func_value_test.go b/executor/aggfuncs/func_value_test.go index 97a03bee0fd03..e83b0e08ee677 100644 --- a/executor/aggfuncs/func_value_test.go +++ b/executor/aggfuncs/func_value_test.go @@ -15,7 +15,8 @@ package aggfuncs_test import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" @@ -59,7 +60,9 @@ func nthValueEvaluateRowUpdateMemDeltaGens(nth int) updateMemDeltaGens { } } -func (s *testSuite) TestMemValue(c *C) { +func TestMemValue(t *testing.T) { + t.Parallel() + firstMemDeltaGens := nthValueEvaluateRowUpdateMemDeltaGens(1) secondMemDeltaGens := nthValueEvaluateRowUpdateMemDeltaGens(2) fifthMemDeltaGens := nthValueEvaluateRowUpdateMemDeltaGens(5) @@ -96,6 +99,6 @@ func (s *testSuite) TestMemValue(c *C) { aggfuncs.DefPartialResult4NthValueSize+aggfuncs.DefValue4StringSize, fifthMemDeltaGens), } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/func_varpop_test.go b/executor/aggfuncs/func_varpop_test.go index 777245a74a1e9..d1b70e4a7ac69 100644 --- a/executor/aggfuncs/func_varpop_test.go +++ b/executor/aggfuncs/func_varpop_test.go @@ -15,7 +15,9 @@ package aggfuncs_test import ( - . 
"github.com/pingcap/check" + "fmt" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" @@ -23,32 +25,42 @@ import ( "github.com/pingcap/tidb/util/set" ) -func (s *testSuite) TestMergePartialResult4Varpop(c *C) { +func TestMergePartialResult4Varpop(t *testing.T) { + t.Parallel() + tests := []aggTest{ buildAggTester(ast.AggFuncVarPop, mysql.TypeDouble, 5, types.NewFloat64Datum(float64(2)), types.NewFloat64Datum(float64(2)/float64(3)), types.NewFloat64Datum(float64(59)/float64(8)-float64(19*19)/float64(8*8))), } for _, test := range tests { - s.testMergePartialResult(c, test) + testMergePartialResult(t, test) } } -func (s *testSuite) TestVarpop(c *C) { +func TestVarpop(t *testing.T) { + t.Parallel() + tests := []aggTest{ buildAggTester(ast.AggFuncVarPop, mysql.TypeDouble, 5, nil, types.NewFloat64Datum(float64(2))), } for _, test := range tests { - s.testAggFunc(c, test) + testAggFunc(t, test) } } -func (s *testSuite) TestMemVarpop(c *C) { +func TestMemVarpop(t *testing.T) { + t.Parallel() + tests := []aggMemTest{ buildAggMemTester(ast.AggFuncVarPop, mysql.TypeDouble, 5, aggfuncs.DefPartialResult4VarPopFloat64Size, defaultUpdateMemDeltaGens, false), buildAggMemTester(ast.AggFuncVarPop, mysql.TypeDouble, 5, aggfuncs.DefPartialResult4VarPopDistinctFloat64Size+set.DefFloat64SetBucketMemoryUsage, distinctUpdateMemDeltaGens, true), } - for _, test := range tests { - s.testAggMemFunc(c, test) + for n, test := range tests { + test := test + t.Run(fmt.Sprintf("%s_%d", test.aggTest.funcName, n), func(t *testing.T) { + t.Parallel() + testAggMemFunc(t, test) + }) } } diff --git a/executor/aggfuncs/row_number_test.go b/executor/aggfuncs/row_number_test.go index 3a76be76be719..d4c001fb10a5e 100644 --- a/executor/aggfuncs/row_number_test.go +++ b/executor/aggfuncs/row_number_test.go @@ -15,18 +15,21 @@ package aggfuncs_test import ( - . 
"github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor/aggfuncs" ) -func (s *testSuite) TestMemRowNumber(c *C) { +func TestMemRowNumber(t *testing.T) { + t.Parallel() + tests := []windowMemTest{ buildWindowMemTester(ast.WindowFuncRowNumber, mysql.TypeLonglong, 0, 0, 4, aggfuncs.DefPartialResult4RowNumberSize, defaultUpdateMemDeltaGens), } for _, test := range tests { - s.testWindowAggMemFunc(c, test) + testWindowAggMemFunc(t, test) } } diff --git a/executor/aggfuncs/window_func_test.go b/executor/aggfuncs/window_func_test.go index a6d2a9e75d333..8e2804c68025f 100644 --- a/executor/aggfuncs/window_func_test.go +++ b/executor/aggfuncs/window_func_test.go @@ -15,9 +15,9 @@ package aggfuncs_test import ( + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" @@ -27,6 +27,9 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" + + "github.com/stretchr/testify/require" ) type windowTest struct { @@ -54,52 +57,54 @@ type windowMemTest struct { updateMemDeltaGens updateMemDeltaGens } -func (s *testSuite) testWindowFunc(c *C, p windowTest) { +func testWindowFunc(t *testing.T, p windowTest) { srcChk := p.genSrcChk() + ctx := mock.NewContext() - desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, p.args, false) - c.Assert(err, IsNil) - finalFunc := aggfuncs.BuildWindowFunctions(s.ctx, desc, 0, p.orderByCols) + desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, p.args, false) + require.NoError(t, err) + finalFunc := aggfuncs.BuildWindowFunctions(ctx, desc, 0, p.orderByCols) finalPr, _ := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) iter := chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { - _, err = 
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) - c.Assert(err, IsNil) + _, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr) + require.NoError(t, err) } - c.Assert(p.numRows, Equals, len(p.results)) + require.Len(t, p.results, p.numRows) for i := 0; i < p.numRows; i++ { - err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) + err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) + require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[i]) - c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[i]) + require.NoError(t, err) + require.Equal(t, 0, result) resultChk.Reset() } finalFunc.ResetPartialResult(finalPr) } -func (s *testSuite) testWindowAggMemFunc(c *C, p windowMemTest) { +func testWindowAggMemFunc(t *testing.T, p windowMemTest) { srcChk := p.windowTest.genSrcChk() + ctx := mock.NewContext() - desc, err := aggregation.NewAggFuncDesc(s.ctx, p.windowTest.funcName, p.windowTest.args, false) - c.Assert(err, IsNil) - finalFunc := aggfuncs.BuildWindowFunctions(s.ctx, desc, 0, p.windowTest.orderByCols) + desc, err := aggregation.NewAggFuncDesc(ctx, p.windowTest.funcName, p.windowTest.args, false) + require.NoError(t, err) + finalFunc := aggfuncs.BuildWindowFunctions(ctx, desc, 0, p.windowTest.orderByCols) finalPr, memDelta := finalFunc.AllocPartialResult() - c.Assert(memDelta, Equals, p.allocMemDelta) + require.Equal(t, p.allocMemDelta, memDelta) updateMemDeltas, err := p.updateMemDeltaGens(srcChk, p.windowTest.dataType) - c.Assert(err, IsNil) + require.NoError(t, err) i := 0 iter := chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { - memDelta, err = finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) - c.Assert(err, IsNil) - c.Assert(memDelta, Equals, 
updateMemDeltas[i]) + memDelta, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr) + require.NoError(t, err) + require.Equal(t, updateMemDeltas[i], memDelta) i++ } } @@ -166,7 +171,9 @@ func buildWindowMemTesterWithArgs(funcName string, tp byte, args []expression.Ex return pt } -func (s *testSuite) TestWindowFunctions(c *C) { +func TestWindowFunctions(t *testing.T) { + t.Parallel() + tests := []windowTest{ buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 1, 1, 1), buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 0, 2, 1, 1), @@ -203,6 +210,6 @@ func (s *testSuite) TestWindowFunctions(c *C) { buildWindowTester(ast.WindowFuncRowNumber, mysql.TypeLonglong, 0, 0, 4, 1, 2, 3, 4), } for _, test := range tests { - s.testWindowFunc(c, test) + testWindowFunc(t, test) } } diff --git a/executor/ddl_test.go b/executor/ddl_test.go index eac7d73c29bb1..714f47ac7ac35 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -903,7 +903,7 @@ func (s *testSuite8) TestShardRowIDBits(c *C) { c.Assert(err, IsNil) maxID := 1<<(64-15-1) - 1 alloc := tbl.Allocators(tk.Se).Get(autoid.RowIDAllocType) - err = alloc.Rebase(int64(maxID)-1, false) + err = alloc.Rebase(context.Background(), int64(maxID)-1, false) c.Assert(err, IsNil) tk.MustExec("insert into t1 values(1)") diff --git a/executor/executor_test.go b/executor/executor_test.go index 510435fef2ef9..57dc87cb7850c 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -75,7 +75,6 @@ import ( "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/rowcodec" "github.com/pingcap/tidb/util/testkit" - "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" "github.com/pingcap/tidb/util/timeutil" "github.com/pingcap/tipb/go-tipb" @@ -89,25 +88,8 @@ import ( func TestT(t *testing.T) { CustomVerboseFlag = true *CustomParallelSuiteFlag = true - logLevel := os.Getenv("log_level") - err := logutil.InitLogger(logutil.NewLogConfig(logLevel, 
logutil.DefaultLogFormat, "", logutil.EmptyFileLogConfig, false)) - if err != nil { - t.Fatal(err) - } - autoid.SetStep(5000) - config.UpdateGlobal(func(conf *config.Config) { - conf.Log.SlowThreshold = 30000 // 30s - conf.TiKVClient.AsyncCommit.SafeWindow = 0 - conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0 - }) - tikv.EnableFailpoints() - tmpDir := config.GetGlobalConfig().TempStoragePath - _ = os.RemoveAll(tmpDir) // clean the uncleared temp file during the last run. - _ = os.MkdirAll(tmpDir, 0755) - testleak.BeforeTest() TestingT(t) - testleak.AfterTestT(t)() } var _ = Suite(&testSuite{&baseTestSuite{}}) diff --git a/executor/insert.go b/executor/insert.go index f5b443387dd75..41b1ae8a738a4 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -302,6 +303,10 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D // Next implements the Executor Next interface. func (e *InsertExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() + if e.collectRuntimeStatsEnabled() { + ctx = context.WithValue(ctx, autoid.AllocatorRuntimeStatsCtxKey, e.stats.AllocatorRuntimeStats) + } + if len(e.children) > 0 && e.children[0] != nil { return insertRowsFromSelect(ctx, e) } diff --git a/executor/insert_common.go b/executor/insert_common.go index 72ed1f51b584f..fb5ebc6c910e3 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -711,7 +711,7 @@ func (e *InsertValues) lazyAdjustAutoIncrementDatum(ctx context.Context, rows [] } // Use the value if it's not null and not 0. 
if recordID != 0 { - err = e.Table.Allocators(e.ctx).Get(autoid.RowIDAllocType).Rebase(recordID, true) + err = e.Table.Allocators(e.ctx).Get(autoid.RowIDAllocType).Rebase(ctx, recordID, true) if err != nil { return nil, err } @@ -801,7 +801,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Dat } // Use the value if it's not null and not 0. if recordID != 0 { - err = e.Table.Allocators(e.ctx).Get(autoid.RowIDAllocType).Rebase(recordID, true) + err = e.Table.Allocators(e.ctx).Get(autoid.RowIDAllocType).Rebase(ctx, recordID, true) if err != nil { return types.Datum{}, err } @@ -877,7 +877,7 @@ func (e *InsertValues) adjustAutoRandomDatum(ctx context.Context, d types.Datum, if !e.ctx.GetSessionVars().AllowAutoRandExplicitInsert { return types.Datum{}, ddl.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomExplicitInsertDisabledErrMsg) } - err = e.rebaseAutoRandomID(recordID, &c.FieldType) + err = e.rebaseAutoRandomID(ctx, recordID, &c.FieldType) if err != nil { return types.Datum{}, err } @@ -936,7 +936,7 @@ func (e *InsertValues) allocAutoRandomID(ctx context.Context, fieldType *types.F return autoRandomID, nil } -func (e *InsertValues) rebaseAutoRandomID(recordID int64, fieldType *types.FieldType) error { +func (e *InsertValues) rebaseAutoRandomID(ctx context.Context, recordID int64, fieldType *types.FieldType) error { if recordID < 0 { return nil } @@ -946,7 +946,7 @@ func (e *InsertValues) rebaseAutoRandomID(recordID int64, fieldType *types.Field layout := autoid.NewShardIDLayout(fieldType, tableInfo.AutoRandomBits) autoRandomID := layout.IncrementalMask() & recordID - return alloc.Rebase(autoRandomID, true) + return alloc.Rebase(ctx, autoRandomID, true) } func (e *InsertValues) adjustImplicitRowID(ctx context.Context, d types.Datum, hasValue bool, c *table.Column) (types.Datum, error) { @@ -963,7 +963,7 @@ func (e *InsertValues) adjustImplicitRowID(ctx context.Context, d types.Datum, h if 
!e.ctx.GetSessionVars().AllowWriteRowID { return types.Datum{}, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.") } - err = e.rebaseImplicitRowID(recordID) + err = e.rebaseImplicitRowID(ctx, recordID) if err != nil { return types.Datum{}, err } @@ -990,7 +990,7 @@ func (e *InsertValues) adjustImplicitRowID(ctx context.Context, d types.Datum, h return d, nil } -func (e *InsertValues) rebaseImplicitRowID(recordID int64) error { +func (e *InsertValues) rebaseImplicitRowID(ctx context.Context, recordID int64) error { if recordID < 0 { return nil } @@ -1000,7 +1000,7 @@ func (e *InsertValues) rebaseImplicitRowID(recordID int64) error { layout := autoid.NewShardIDLayout(types.NewFieldType(mysql.TypeLonglong), tableInfo.ShardRowIDBits) newTiDBRowIDBase := layout.IncrementalMask() & recordID - return alloc.Rebase(newTiDBRowIDBase, true) + return alloc.Rebase(ctx, newTiDBRowIDBase, true) } func (e *InsertValues) handleWarning(err error) { @@ -1013,10 +1013,9 @@ func (e *InsertValues) collectRuntimeStatsEnabled() bool { if e.stats == nil { snapshotStats := &txnsnapshot.SnapshotRuntimeStats{} e.stats = &InsertRuntimeStat{ - BasicRuntimeStats: e.runtimeStats, - SnapshotRuntimeStats: snapshotStats, - Prefetch: 0, - CheckInsertTime: 0, + BasicRuntimeStats: e.runtimeStats, + SnapshotRuntimeStats: snapshotStats, + AllocatorRuntimeStats: autoid.NewAllocatorRuntimeStats(), } e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats) } @@ -1140,20 +1139,46 @@ func (e *InsertValues) addRecordWithAutoIDHint(ctx context.Context, row []types. 
type InsertRuntimeStat struct { *execdetails.BasicRuntimeStats *txnsnapshot.SnapshotRuntimeStats + *autoid.AllocatorRuntimeStats CheckInsertTime time.Duration Prefetch time.Duration } func (e *InsertRuntimeStat) String() string { + buf := bytes.NewBuffer(make([]byte, 0, 32)) + var allocatorStatsStr string + if e.AllocatorRuntimeStats != nil { + allocatorStatsStr = e.AllocatorRuntimeStats.String() + } if e.CheckInsertTime == 0 { // For replace statement. + if allocatorStatsStr != "" { + buf.WriteString(allocatorStatsStr) + } if e.Prefetch > 0 && e.SnapshotRuntimeStats != nil { - return fmt.Sprintf("prefetch: %v, rpc:{%v}", execdetails.FormatDuration(e.Prefetch), e.SnapshotRuntimeStats.String()) + if buf.Len() > 0 { + buf.WriteString(", ") + } + buf.WriteString("prefetch: ") + buf.WriteString(execdetails.FormatDuration(e.Prefetch)) + buf.WriteString(", rpc: {") + buf.WriteString(e.SnapshotRuntimeStats.String()) + buf.WriteString("}") + return buf.String() } return "" } - buf := bytes.NewBuffer(make([]byte, 0, 32)) - buf.WriteString(fmt.Sprintf("prepare:%v, ", execdetails.FormatDuration(time.Duration(e.BasicRuntimeStats.GetTime())-e.CheckInsertTime))) + if allocatorStatsStr != "" { + buf.WriteString("prepare: {total: ") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.BasicRuntimeStats.GetTime()) - e.CheckInsertTime)) + buf.WriteString(", ") + buf.WriteString(allocatorStatsStr) + buf.WriteString("}, ") + } else { + buf.WriteString("prepare: ") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.BasicRuntimeStats.GetTime()) - e.CheckInsertTime)) + buf.WriteString(", ") + } if e.Prefetch > 0 { buf.WriteString(fmt.Sprintf("check_insert: {total_time: %v, mem_insert_time: %v, prefetch: %v", execdetails.FormatDuration(e.CheckInsertTime), @@ -1185,6 +1210,9 @@ func (e *InsertRuntimeStat) Clone() execdetails.RuntimeStats { basicStats := e.BasicRuntimeStats.Clone() newRs.BasicRuntimeStats = basicStats.(*execdetails.BasicRuntimeStats) } + if 
e.AllocatorRuntimeStats != nil { + newRs.AllocatorRuntimeStats = e.AllocatorRuntimeStats.Clone() + } return newRs } @@ -1210,6 +1238,13 @@ func (e *InsertRuntimeStat) Merge(other execdetails.RuntimeStats) { e.BasicRuntimeStats.Merge(tmp.BasicRuntimeStats) } } + if tmp.AllocatorRuntimeStats != nil { + if e.AllocatorRuntimeStats == nil { + e.AllocatorRuntimeStats = tmp.AllocatorRuntimeStats.Clone() + } else { + e.AllocatorRuntimeStats.Merge(tmp.AllocatorRuntimeStats) + } + } e.Prefetch += tmp.Prefetch e.CheckInsertTime += tmp.CheckInsertTime } diff --git a/executor/insert_test.go b/executor/insert_test.go index 23fd0ba99cba8..7de4702e4e755 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -1460,10 +1460,10 @@ func (s *testSuite10) TestInsertRuntimeStat(c *C) { Prefetch: 1 * time.Second, } stats.BasicRuntimeStats.Record(5*time.Second, 1) - c.Assert(stats.String(), Equals, "prepare:3s, check_insert: {total_time: 2s, mem_insert_time: 1s, prefetch: 1s}") + c.Assert(stats.String(), Equals, "prepare: 3s, check_insert: {total_time: 2s, mem_insert_time: 1s, prefetch: 1s}") c.Assert(stats.String(), Equals, stats.Clone().String()) stats.Merge(stats.Clone()) - c.Assert(stats.String(), Equals, "prepare:6s, check_insert: {total_time: 4s, mem_insert_time: 2s, prefetch: 2s}") + c.Assert(stats.String(), Equals, "prepare: 6s, check_insert: {total_time: 4s, mem_insert_time: 2s, prefetch: 2s}") } func (s *testSerialSuite) TestDuplicateEntryMessage(c *C) { diff --git a/executor/main_test.go b/executor/main_test.go index b5ac20bf97e21..cb02b1576e59f 100644 --- a/executor/main_test.go +++ b/executor/main_test.go @@ -15,14 +15,30 @@ package executor import ( + "os" "testing" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/util/testbridge" + "github.com/tikv/client-go/v2/tikv" "go.uber.org/goleak" ) func TestMain(m *testing.M) { testbridge.WorkaroundGoCheckFlags() + + autoid.SetStep(5000) + config.UpdateGlobal(func(conf 
*config.Config) { + conf.Log.SlowThreshold = 30000 // 30s + conf.TiKVClient.AsyncCommit.SafeWindow = 0 + conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0 + }) + tikv.EnableFailpoints() + tmpDir := config.GetGlobalConfig().TempStoragePath + _ = os.RemoveAll(tmpDir) // clean the uncleared temp file during the last run. + _ = os.MkdirAll(tmpDir, 0755) + opts := []goleak.Option{ goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"), goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), diff --git a/executor/replace.go b/executor/replace.go index cf96ec99320bd..5531b55599579 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -250,6 +251,10 @@ func (e *ReplaceExec) exec(ctx context.Context, newRows [][]types.Datum) error { // Next implements the Executor Next interface. 
func (e *ReplaceExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() + if e.collectRuntimeStatsEnabled() { + ctx = context.WithValue(ctx, autoid.AllocatorRuntimeStatsCtxKey, e.stats.AllocatorRuntimeStats) + } + if len(e.children) > 0 && e.children[0] != nil { return insertRowsFromSelect(ctx, e) } diff --git a/executor/show.go b/executor/show.go index b22c7bf4543aa..4c774fe9daa76 100644 --- a/executor/show.go +++ b/executor/show.go @@ -1071,13 +1071,13 @@ func ConstructResultOfShowCreateTable(ctx sessionctx.Context, tableInfo *model.T } if tableInfo.PlacementPolicyRef != nil { - fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=`%s` */", tableInfo.PlacementPolicyRef.Name.String()) + fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=%s */", stringutil.Escape(tableInfo.PlacementPolicyRef.Name.O, sqlMode)) } // add direct placement info here appendDirectPlacementInfo(tableInfo.DirectPlacementOpts, buf) // add partition info here. - appendPartitionInfo(tableInfo.Partition, buf) + appendPartitionInfo(tableInfo.Partition, buf, sqlMode) return nil } @@ -1228,7 +1228,7 @@ func appendDirectPlacementInfo(directPlacementOpts *model.PlacementSettings, buf fmt.Fprintf(buf, " */") } -func appendPartitionInfo(partitionInfo *model.PartitionInfo, buf *bytes.Buffer) { +func appendPartitionInfo(partitionInfo *model.PartitionInfo, buf *bytes.Buffer, sqlMode mysql.SQLMode) { if partitionInfo == nil { return } @@ -1267,6 +1267,14 @@ func appendPartitionInfo(partitionInfo *model.PartitionInfo, buf *bytes.Buffer) for i, def := range partitionInfo.Definitions { lessThans := strings.Join(def.LessThan, ",") fmt.Fprintf(buf, " PARTITION `%s` VALUES LESS THAN (%s)", def.Name, lessThans) + if def.DirectPlacementOpts != nil { + // add direct placement info here + appendDirectPlacementInfo(def.DirectPlacementOpts, buf) + } + if def.PlacementPolicyRef != nil { + // add placement ref info here + fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=%s */", 
stringutil.Escape(def.PlacementPolicyRef.Name.O, sqlMode)) + } if i < len(partitionInfo.Definitions)-1 { buf.WriteString(",\n") } else { @@ -1290,6 +1298,14 @@ func appendPartitionInfo(partitionInfo *model.PartitionInfo, buf *bytes.Buffer) } } fmt.Fprintf(buf, " PARTITION `%s` VALUES IN (%s)", def.Name, values.String()) + if def.DirectPlacementOpts != nil { + // add direct placement info here + appendDirectPlacementInfo(def.DirectPlacementOpts, buf) + } + if def.PlacementPolicyRef != nil { + // add placement ref info here + fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=%s */", stringutil.Escape(def.PlacementPolicyRef.Name.O, sqlMode)) + } if i < len(partitionInfo.Definitions)-1 { buf.WriteString(",\n") } else { @@ -1332,10 +1348,12 @@ func ConstructResultOfShowCreateDatabase(ctx sessionctx.Context, dbInfo *model.D // MySQL 5.7 always show the charset info but TiDB may ignore it, which makes a slight difference. We keep this // behavior unchanged because it is trivial enough. if dbInfo.DirectPlacementOpts != nil { - fmt.Fprintf(buf, " %s", dbInfo.DirectPlacementOpts) + // add direct placement info here + appendDirectPlacementInfo(dbInfo.DirectPlacementOpts, buf) } if dbInfo.PlacementPolicyRef != nil { - fmt.Fprintf(buf, " PLACEMENT POLICY = %s", stringutil.Escape(dbInfo.PlacementPolicyRef.Name.O, sqlMode)) + // add placement ref info here + fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=%s */", stringutil.Escape(dbInfo.PlacementPolicyRef.Name.O, sqlMode)) } return nil } diff --git a/executor/show_test.go b/executor/show_test.go index 3ded6078fb05b..527ef4441e32e 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -151,13 +151,13 @@ func (s *testSuite5) TestShowWarningsForExprPushdown(c *C) { tk.MustExec(testSQL) tk.MustExec("explain select * from show_warnings_expr_pushdown where date_add(value, interval 1 day) = '2020-01-01'") c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(1)) - tk.MustQuery("show 
warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt) can not be pushed to tikv")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt, return type: datetime(6)) can not be pushed to tikv")) tk.MustExec("explain select max(date_add(value, interval 1 day)) from show_warnings_expr_pushdown group by a") c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(2)) - tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt) can not be pushed to tikv", "Warning|1105|Aggregation can not be pushed to tikv because arguments of AggFunc `max` contains unsupported exprs")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt, return type: datetime(6)) can not be pushed to tikv", "Warning|1105|Aggregation can not be pushed to tikv because arguments of AggFunc `max` contains unsupported exprs")) tk.MustExec("explain select max(a) from show_warnings_expr_pushdown group by date_add(value, interval 1 day)") c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(2)) - tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt) can not be pushed to tikv", "Warning|1105|Aggregation can not be pushed to tikv because groupByItems contain unsupported exprs")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt, return type: datetime(6)) can not be pushed to tikv", "Warning|1105|Aggregation can not be pushed to tikv because groupByItems contain unsupported exprs")) tk.MustExec("set tidb_opt_distinct_agg_push_down=0") tk.MustExec("explain select max(distinct a) from show_warnings_expr_pushdown group by value") 
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(1)) @@ -947,29 +947,24 @@ func (s *testSuite5) TestShowCreateTable(c *C) { func (s *testAutoRandomSuite) TestShowCreateTablePlacement(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + defer tk.MustExec(`DROP TABLE IF EXISTS t`) // case for direct opts tk.MustExec(`DROP TABLE IF EXISTS t`) tk.MustExec("create table t(a int) " + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + "FOLLOWERS=2 " + "CONSTRAINTS=\"[+disk=ssd]\"") tk.MustQuery(`show create table t`).Check(testutil.RowsWithSep("|", "t CREATE TABLE `t` (\n"+ " `a` int(11) DEFAULT NULL\n"+ ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin "+ - "/*T![placement] PRIMARY_REGION=\"cn-east-1\" "+ - "REGIONS=\"cn-east-1, cn-east-2\" "+ - "FOLLOWERS=2 "+ + "/*T![placement] FOLLOWERS=2 "+ "CONSTRAINTS=\"[+disk=ssd]\" */", )) // case for policy tk.MustExec(`DROP TABLE IF EXISTS t`) tk.MustExec("create placement policy x " + - "PRIMARY_REGION=\"cn-east-1\" " + - "REGIONS=\"cn-east-1, cn-east-2\" " + "FOLLOWERS=2 " + "CONSTRAINTS=\"[+disk=ssd]\" ") tk.MustExec("create table t(a int)" + @@ -981,7 +976,16 @@ func (s *testAutoRandomSuite) TestShowCreateTablePlacement(c *C) { "/*T![placement] PLACEMENT POLICY=`x` */", )) + // case for policy with quotes tk.MustExec(`DROP TABLE IF EXISTS t`) + tk.MustExec("create table t(a int)" + + "/*T![placement] PLACEMENT POLICY=\"x\" */") + tk.MustQuery(`show create table t`).Check(testutil.RowsWithSep("|", + "t CREATE TABLE `t` (\n"+ + " `a` int(11) DEFAULT NULL\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin "+ + "/*T![placement] PLACEMENT POLICY=`x` */", + )) } func (s *testAutoRandomSuite) TestShowCreateTableAutoRandom(c *C) { diff --git a/executor/simple.go b/executor/simple.go index f4fb4715e3a90..ab5156d787b7e 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -1342,7 +1342,7 @@ func (e *SimpleExec) userAuthPlugin(name 
string, host string) (string, error) { func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error { var u, h string - if s.User == nil { + if s.User == nil || s.User.CurrentUser { if e.ctx.GetSessionVars().User == nil { return errors.New("Session error is empty") } diff --git a/executor/simple_test.go b/executor/simple_test.go index bcea09d715c0b..b24fc89f9fd57 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -889,3 +889,18 @@ func (s *testSuite3) TestIssue23649(c *C) { _, err = tk.Exec("GRANT bogusrole to nonexisting;") c.Assert(err.Error(), Equals, "[executor:3523]Unknown authorization ID `bogusrole`@`%`") } + +func (s *testSuite3) TestSetCurrentUserPwd(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("CREATE USER issue28534;") + defer func() { + tk.MustExec("DROP USER IF EXISTS issue28534;") + }() + + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "issue28534", Hostname: "localhost", CurrentUser: true, AuthUsername: "issue28534", AuthHostname: "%"}, nil, nil), IsTrue) + tk.MustExec(`SET PASSWORD FOR CURRENT_USER() = "43582eussi"`) + + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + result := tk.MustQuery(`SELECT authentication_string FROM mysql.User WHERE User="issue28534"`) + result.Check(testkit.Rows(auth.EncodePassword("43582eussi"))) +} diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 12cc111b9a561..f7331c05b9e73 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -17,6 +17,7 @@ package executor_test import ( "bytes" "fmt" + "math" "math/rand" "strings" "sync" @@ -32,10 +33,12 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" + plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/unistore" "github.com/pingcap/tidb/util/israce" + 
"github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" "github.com/tikv/client-go/v2/testutils" @@ -544,6 +547,57 @@ func (s *tiflashTestSuite) TestMppEnum(c *C) { tk.MustQuery("select t1.b from t t1 join t t2 on t1.a = t2.a order by t1.b").Check(testkit.Rows("aca", "bca", "zca")) } +func (s *tiflashTestSuite) TestTiFlashPlanCacheable(c *C) { + tk := testkit.NewTestKit(c, s.store) + orgEnable := plannercore.PreparedPlanCacheEnabled() + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + }() + plannercore.SetPreparedPlanCache(true) + + var err error + tk.Se, err = session.CreateSession4TestWithOpt(s.store, &session.Opt{ + PreparedPlanCache: kvcache.NewSimpleLRUCache(100, 0.1, math.MaxUint64), + }) + c.Assert(err, IsNil) + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int);") + tk.MustExec("set @@tidb_enable_collect_execution_info=0;") + tk.MustExec("alter table test.t set tiflash replica 1") + tb := testGetTableByName(c, tk.Se, "test", "t") + err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tikv, tiflash'") + tk.MustExec("insert into t values(1);") + tk.MustExec("prepare stmt from 'select /*+ read_from_storage(tiflash[t]) */ * from t;';") + tk.MustQuery("execute stmt;").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + tk.MustQuery("execute stmt;").Check(testkit.Rows("1")) + // The TiFlash plan can not be cached. 
+ tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + + tk.MustExec("prepare stmt from 'select /*+ read_from_storage(tikv[t]) */ * from t;';") + tk.MustQuery("execute stmt;").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + tk.MustQuery("execute stmt;").Check(testkit.Rows("1")) + // The TiKV plan can be cached. + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + // test the mpp plan + tk.MustExec("set @@session.tidb_allow_mpp = 1;") + tk.MustExec("set @@session.tidb_enforce_mpp = 1;") + tk.MustExec("prepare stmt from 'select count(t1.a) from t t1 join t t2 on t1.a = t2.a where t1.a > ?;';") + tk.MustExec("set @a = 0;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + + tk.MustExec("set @a = 1;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("0")) + // The TiFlash plan can not be cached. + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) +} + func (s *tiflashTestSuite) TestDispatchTaskRetry(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/executor/update.go b/executor/update.go index 5b946958302ab..06253479996c6 100644 --- a/executor/update.go +++ b/executor/update.go @@ -15,6 +15,7 @@ package executor import ( + "bytes" "context" "fmt" "runtime/trace" @@ -23,11 +24,13 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta/autoid" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/memory" "github.com/tikv/client-go/v2/txnkv/txnsnapshot" ) @@ -58,7 +61,7 @@ type UpdateExec struct { drained bool memTracker *memory.Tracker - stats *runtimeStatsWithSnapshot + stats 
*updateRuntimeStats handles []kv.Handle tableUpdatable []bool @@ -217,6 +220,9 @@ func (e *UpdateExec) unmatchedOuterRow(tblPos plannercore.TblColPosInfo, waitUpd func (e *UpdateExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if !e.drained { + if e.collectRuntimeStatsEnabled() { + ctx = context.WithValue(ctx, autoid.AllocatorRuntimeStatsCtxKey, e.stats.AllocatorRuntimeStats) + } numRows, err := e.updateRows(ctx) if err != nil { return err @@ -414,7 +420,7 @@ func (e *UpdateExec) Close() error { e.setMessage() if e.runtimeStats != nil && e.stats != nil { txn, err := e.ctx.Txn(false) - if err == nil && txn.GetSnapshot() != nil { + if err == nil && txn.Valid() && txn.GetSnapshot() != nil { txn.GetSnapshot().SetOption(kv.CollectRuntimeStats, nil) } } @@ -442,9 +448,9 @@ func (e *UpdateExec) setMessage() { func (e *UpdateExec) collectRuntimeStatsEnabled() bool { if e.runtimeStats != nil { if e.stats == nil { - snapshotStats := &txnsnapshot.SnapshotRuntimeStats{} - e.stats = &runtimeStatsWithSnapshot{ - SnapshotRuntimeStats: snapshotStats, + e.stats = &updateRuntimeStats{ + SnapshotRuntimeStats: &txnsnapshot.SnapshotRuntimeStats{}, + AllocatorRuntimeStats: autoid.NewAllocatorRuntimeStats(), } e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats) } @@ -452,3 +458,71 @@ func (e *UpdateExec) collectRuntimeStatsEnabled() bool { } return false } + +// updateRuntimeStats is the execution stats about update statements. 
+type updateRuntimeStats struct { + *txnsnapshot.SnapshotRuntimeStats + *autoid.AllocatorRuntimeStats +} + +func (e *updateRuntimeStats) String() string { + if e.SnapshotRuntimeStats == nil && e.AllocatorRuntimeStats == nil { + return "" + } + buf := bytes.NewBuffer(make([]byte, 0, 16)) + if e.SnapshotRuntimeStats != nil { + stats := e.SnapshotRuntimeStats.String() + if stats != "" { + buf.WriteString(stats) + } + } + if e.AllocatorRuntimeStats != nil { + stats := e.AllocatorRuntimeStats.String() + if stats != "" { + if buf.Len() > 0 { + buf.WriteString(", ") + } + buf.WriteString(stats) + } + } + return buf.String() +} + +// Clone implements the RuntimeStats interface. +func (e *updateRuntimeStats) Clone() execdetails.RuntimeStats { + newRs := &updateRuntimeStats{} + if e.SnapshotRuntimeStats != nil { + snapshotStats := e.SnapshotRuntimeStats.Clone() + newRs.SnapshotRuntimeStats = snapshotStats + } + if e.AllocatorRuntimeStats != nil { + newRs.AllocatorRuntimeStats = e.AllocatorRuntimeStats.Clone() + } + return newRs +} + +// Merge implements the RuntimeStats interface. +func (e *updateRuntimeStats) Merge(other execdetails.RuntimeStats) { + tmp, ok := other.(*updateRuntimeStats) + if !ok { + return + } + if tmp.SnapshotRuntimeStats != nil { + if e.SnapshotRuntimeStats == nil { + snapshotStats := tmp.SnapshotRuntimeStats.Clone() + e.SnapshotRuntimeStats = snapshotStats + } else { + e.SnapshotRuntimeStats.Merge(tmp.SnapshotRuntimeStats) + } + } + if tmp.AllocatorRuntimeStats != nil { + if e.AllocatorRuntimeStats == nil { + e.AllocatorRuntimeStats = tmp.AllocatorRuntimeStats.Clone() + } + } +} + +// Tp implements the RuntimeStats interface. 
+func (e *updateRuntimeStats) Tp() int { + return execdetails.TpUpdateRuntimeStats +} diff --git a/executor/write.go b/executor/write.go index 28925957bd764..835780c947214 100644 --- a/executor/write.go +++ b/executor/write.go @@ -111,14 +111,14 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old if err != nil { return false, err } - if err = t.Allocators(sctx).Get(autoid.RowIDAllocType).Rebase(recordID, true); err != nil { + if err = t.Allocators(sctx).Get(autoid.RowIDAllocType).Rebase(ctx, recordID, true); err != nil { return false, err } } if col.IsPKHandleColumn(t.Meta()) { handleChanged = true // Rebase auto random id if the field is changed. - if err := rebaseAutoRandomValue(sctx, t, &newData[i], col); err != nil { + if err := rebaseAutoRandomValue(ctx, sctx, t, &newData[i], col); err != nil { return false, err } } @@ -222,7 +222,7 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old return true, nil } -func rebaseAutoRandomValue(sctx sessionctx.Context, t table.Table, newData *types.Datum, col *table.Column) error { +func rebaseAutoRandomValue(ctx context.Context, sctx sessionctx.Context, t table.Table, newData *types.Datum, col *table.Column) error { tableInfo := t.Meta() if !tableInfo.ContainsAutoRandomBits() { return nil @@ -237,7 +237,7 @@ func rebaseAutoRandomValue(sctx sessionctx.Context, t table.Table, newData *type layout := autoid.NewShardIDLayout(&col.FieldType, tableInfo.AutoRandomBits) // Set bits except incremental_bits to zero. 
recordID = recordID & (1< mysql.MaxDecimalWidth { tp.Flen = mysql.MaxDecimalWidth } types.SetBinChsClnFlag(tp) diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 0f93dbadcd3d6..70514ac92dd85 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -609,6 +609,8 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { datetimeColumn := dg.genColumn(mysql.TypeDatetime, 6) binaryStringColumn := dg.genColumn(mysql.TypeString, 7) binaryStringColumn.RetType.Collate = charset.CollationBin + int32Column := dg.genColumn(mysql.TypeLong, 8) + float32Column := dg.genColumn(mysql.TypeFloat, 9) function, err := NewFunction(mock.NewContext(), ast.JSONLength, types.NewFieldType(mysql.TypeLonglong), jsonColumn) c.Assert(err, IsNil) @@ -656,28 +658,31 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { c.Assert(err, IsNil) exprs = append(exprs, function) + validDecimalType := types.NewFieldType(mysql.TypeNewDecimal) + validDecimalType.Flen = 20 + validDecimalType.Decimal = 2 // CastIntAsDecimal - function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeNewDecimal), intColumn) + function, err = NewFunction(mock.NewContext(), ast.Cast, validDecimalType, intColumn) c.Assert(err, IsNil) exprs = append(exprs, function) // CastRealAsDecimal - function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeNewDecimal), realColumn) + function, err = NewFunction(mock.NewContext(), ast.Cast, validDecimalType, realColumn) c.Assert(err, IsNil) exprs = append(exprs, function) // CastDecimalAsDecimal - function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeNewDecimal), decimalColumn) + function, err = NewFunction(mock.NewContext(), ast.Cast, validDecimalType, decimalColumn) c.Assert(err, IsNil) exprs = append(exprs, function) // CastStringAsDecimal - function, err = NewFunction(mock.NewContext(), ast.Cast, 
types.NewFieldType(mysql.TypeNewDecimal), stringColumn) + function, err = NewFunction(mock.NewContext(), ast.Cast, validDecimalType, stringColumn) c.Assert(err, IsNil) exprs = append(exprs, function) // CastTimeAsDecimal - function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeNewDecimal), datetimeColumn) + function, err = NewFunction(mock.NewContext(), ast.Cast, validDecimalType, datetimeColumn) c.Assert(err, IsNil) exprs = append(exprs, function) @@ -961,6 +966,16 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { c.Assert(function.(*ScalarFunction).Function.PbCode(), Equals, tipb.ScalarFuncSig_StrToDateDatetime) exprs = append(exprs, function) + // cast Int32 to Int32 + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeLong), int32Column) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // cast float32 to float32 + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeFloat), float32Column) + c.Assert(err, IsNil) + exprs = append(exprs, function) + canPush := CanExprsPushDown(sc, exprs, client, kv.TiFlash) c.Assert(canPush, Equals, true) @@ -985,6 +1000,28 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { c.Assert(err, IsNil) exprs = append(exprs, function) + // Cast to Int32: not supported + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeLong), stringColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // Cast to Float: not supported + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeFloat), intColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // Cast to invalid Decimal Type: not supported + function, err = NewFunction(mock.NewContext(), ast.Cast, types.NewFieldType(mysql.TypeNewDecimal), intColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + // cast Int32 to UInt32 + unsignedInt32Type := 
types.NewFieldType(mysql.TypeLong) + unsignedInt32Type.Flag = mysql.UnsignedFlag + function, err = NewFunction(mock.NewContext(), ast.Cast, unsignedInt32Type, int32Column) + c.Assert(err, IsNil) + exprs = append(exprs, function) + pushed, remained := PushDownExprs(sc, exprs, client, kv.TiFlash) c.Assert(len(pushed), Equals, 0) c.Assert(len(remained), Equals, len(exprs)) diff --git a/expression/expression.go b/expression/expression.go index 0e574ee7dc375..b873d880d481d 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1004,6 +1004,13 @@ func scalarExprSupportedByTiKV(sf *ScalarFunction) bool { return false } +func isValidTiFlashDecimalType(tp *types.FieldType) bool { + if tp.Tp != mysql.TypeNewDecimal { + return false + } + return tp.Flen > 0 && tp.Flen <= 65 && tp.Decimal >= 0 && tp.Decimal <= 30 && tp.Flen >= tp.Decimal +} + func scalarExprSupportedByFlash(function *ScalarFunction) bool { switch function.FuncName.L { case ast.Floor, ast.Ceil, ast.Ceiling: @@ -1040,16 +1047,27 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool { return true } case ast.Cast: + sourceType := function.GetArgs()[0].GetType() + retType := function.RetType switch function.Function.PbCode() { - case tipb.ScalarFuncSig_CastIntAsTime: + case tipb.ScalarFuncSig_CastDecimalAsInt, tipb.ScalarFuncSig_CastIntAsInt, tipb.ScalarFuncSig_CastRealAsInt, tipb.ScalarFuncSig_CastTimeAsInt, + tipb.ScalarFuncSig_CastStringAsInt /*, tipb.ScalarFuncSig_CastDurationAsInt, tipb.ScalarFuncSig_CastJsonAsInt*/ : + // TiFlash cast only support cast to Int64 or the source type is the same as the target type + return (sourceType.Tp == retType.Tp && mysql.HasUnsignedFlag(sourceType.Flag) == mysql.HasUnsignedFlag(retType.Flag)) || retType.Tp == mysql.TypeLonglong + case tipb.ScalarFuncSig_CastIntAsReal, tipb.ScalarFuncSig_CastRealAsReal, tipb.ScalarFuncSig_CastStringAsReal: /*, tipb.ScalarFuncSig_CastDecimalAsReal, + tipb.ScalarFuncSig_CastDurationAsReal, 
tipb.ScalarFuncSig_CastTimeAsReal, tipb.ScalarFuncSig_CastJsonAsReal*/ + // TiFlash cast only support cast to Float64 or the source type is the same as the target type + return sourceType.Tp == retType.Tp || retType.Tp == mysql.TypeDouble + case tipb.ScalarFuncSig_CastDecimalAsDecimal, tipb.ScalarFuncSig_CastIntAsDecimal, tipb.ScalarFuncSig_CastRealAsDecimal, tipb.ScalarFuncSig_CastTimeAsDecimal, + tipb.ScalarFuncSig_CastStringAsDecimal /*, tipb.ScalarFuncSig_CastDurationAsDecimal, tipb.ScalarFuncSig_CastJsonAsDecimal*/ : + return isValidTiFlashDecimalType(function.RetType) + case tipb.ScalarFuncSig_CastDecimalAsString, tipb.ScalarFuncSig_CastIntAsString, tipb.ScalarFuncSig_CastRealAsString, tipb.ScalarFuncSig_CastTimeAsString, + tipb.ScalarFuncSig_CastStringAsString /*, tipb.ScalarFuncSig_CastDurationAsString, tipb.ScalarFuncSig_CastJsonAsString*/ : + return true + case tipb.ScalarFuncSig_CastDecimalAsTime, tipb.ScalarFuncSig_CastIntAsTime, tipb.ScalarFuncSig_CastRealAsTime, tipb.ScalarFuncSig_CastTimeAsTime, + tipb.ScalarFuncSig_CastStringAsTime /*, tipb.ScalarFuncSig_CastDurationAsTime, tipb.ScalarFuncSig_CastJsonAsTime*/ : // ban the function of casting year type as time type pushing down to tiflash because of https://github.com/pingcap/tidb/issues/26215 return function.GetArgs()[0].GetType().Tp != mysql.TypeYear - case tipb.ScalarFuncSig_CastIntAsInt, tipb.ScalarFuncSig_CastIntAsReal, tipb.ScalarFuncSig_CastIntAsDecimal, tipb.ScalarFuncSig_CastIntAsString, - tipb.ScalarFuncSig_CastRealAsInt, tipb.ScalarFuncSig_CastRealAsReal, tipb.ScalarFuncSig_CastRealAsDecimal, tipb.ScalarFuncSig_CastRealAsString, tipb.ScalarFuncSig_CastRealAsTime, - tipb.ScalarFuncSig_CastStringAsInt, tipb.ScalarFuncSig_CastStringAsReal, tipb.ScalarFuncSig_CastStringAsDecimal, tipb.ScalarFuncSig_CastStringAsString, tipb.ScalarFuncSig_CastStringAsTime, - tipb.ScalarFuncSig_CastDecimalAsInt /*, tipb.ScalarFuncSig_CastDecimalAsReal*/, tipb.ScalarFuncSig_CastDecimalAsDecimal, 
tipb.ScalarFuncSig_CastDecimalAsString, tipb.ScalarFuncSig_CastDecimalAsTime, - tipb.ScalarFuncSig_CastTimeAsInt /*, tipb.ScalarFuncSig_CastTimeAsReal*/, tipb.ScalarFuncSig_CastTimeAsDecimal, tipb.ScalarFuncSig_CastTimeAsTime, tipb.ScalarFuncSig_CastTimeAsString: - return true } case ast.DateAdd, ast.AddDate: switch function.Function.PbCode() { @@ -1183,7 +1201,7 @@ func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType if storeType == kv.UnSpecified { storageName = "storage layer" } - pc.sc.AppendWarning(errors.New("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ") can not be pushed to " + storageName)) + pc.sc.AppendWarning(errors.New("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ", return type: " + scalarFunc.RetType.CompactStr() + ") can not be pushed to " + storageName)) } return false } diff --git a/go.mod b/go.mod index 7787e47638e2a..845066625d50e 100644 --- a/go.mod +++ b/go.mod @@ -49,7 +49,7 @@ require ( github.com/pingcap/fn v0.0.0-20200306044125-d5540d389059 github.com/pingcap/kvproto v0.0.0-20210806074406-317f69fb54b4 github.com/pingcap/log v0.0.0-20210906054005-afc726e70354 - github.com/pingcap/parser v0.0.0-20210917114242-ac711116bdff + github.com/pingcap/parser v0.0.0-20211004012448-687005894c4e github.com/pingcap/sysutil v0.0.0-20210730114356-fcd8a63f68c5 github.com/pingcap/tidb-tools v5.0.3+incompatible github.com/pingcap/tipb v0.0.0-20210802080519-94b831c6db55 diff --git a/go.sum b/go.sum index 145d1d5ecbdf5..8e438145e44d6 100644 --- a/go.sum +++ b/go.sum @@ -603,8 +603,8 @@ github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7/go.mod h1:8AanEdAHATuR github.com/pingcap/log v0.0.0-20210906054005-afc726e70354 h1:SvWCbCPh1YeHd9yQLksvJYAgft6wLTY1aNG81tpyscQ= github.com/pingcap/log v0.0.0-20210906054005-afc726e70354/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/parser 
v0.0.0-20210525032559-c37778aff307/go.mod h1:xZC8I7bug4GJ5KtHhgAikjTfU4kBv1Sbo3Pf1MZ6lVw= -github.com/pingcap/parser v0.0.0-20210917114242-ac711116bdff h1:LiwvvutmyeSkFkdVM09mH6KK+OeDVJzX7WKy9Lf0ri0= -github.com/pingcap/parser v0.0.0-20210917114242-ac711116bdff/go.mod h1:+xcMiiZzdIktT/Nqdfm81dkECJ2EPuoAYywd57py4Pk= +github.com/pingcap/parser v0.0.0-20211004012448-687005894c4e h1:dPMDpj+7ng9qEWoT3n6qjpB1ohz79uTLVM6ILW+ZMT0= +github.com/pingcap/parser v0.0.0-20211004012448-687005894c4e/go.mod h1:+xcMiiZzdIktT/Nqdfm81dkECJ2EPuoAYywd57py4Pk= github.com/pingcap/sysutil v0.0.0-20200206130906-2bfa6dc40bcd/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3/go.mod h1:tckvA041UWP+NqYzrJ3fMgC/Hw9wnmQ/tUkp/JaHly8= github.com/pingcap/sysutil v0.0.0-20210730114356-fcd8a63f68c5 h1:7rvAtZe/ZUzOKzgriNPQoBNvleJXBk4z7L3Z47+tS98= diff --git a/meta/autoid/autoid.go b/meta/autoid/autoid.go index 0700b8ef35a38..04bcaadbccf8f 100644 --- a/meta/autoid/autoid.go +++ b/meta/autoid/autoid.go @@ -15,8 +15,10 @@ package autoid import ( + "bytes" "context" "math" + "strconv" "sync" "time" @@ -30,7 +32,10 @@ import ( "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/logutil" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" + tikvutil "github.com/tikv/client-go/v2/util" "go.uber.org/zap" ) @@ -142,7 +147,7 @@ type Allocator interface { // Rebase rebases the autoID base for table with tableID and the new base value. // If allocIDs is true, it will allocate some IDs and save to the cache. // If allocIDs is false, it will not allocate IDs. - Rebase(newBase int64, allocIDs bool) error + Rebase(ctx context.Context, newBase int64, allocIDs bool) error // ForceRebase set the next global auto ID to newBase. 
ForceRebase(newBase int64) error @@ -244,7 +249,7 @@ func (alloc *allocator) NextGlobalAutoID() (int64, error) { return autoID + 1, err } -func (alloc *allocator) rebase4Unsigned(requiredBase uint64, allocIDs bool) error { +func (alloc *allocator) rebase4Unsigned(ctx context.Context, requiredBase uint64, allocIDs bool) error { // Satisfied by alloc.base, nothing to do. if requiredBase <= uint64(alloc.base) { return nil @@ -254,9 +259,22 @@ func (alloc *allocator) rebase4Unsigned(requiredBase uint64, allocIDs bool) erro alloc.base = int64(requiredBase) return nil } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.rebaseCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } var newBase, newEnd uint64 startTime := time.Now() - err := kv.RunInNewTxn(context.Background(), alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { + err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } idAcc := alloc.getIDAccessor(txn) currentEnd, err1 := idAcc.Get() if err1 != nil { @@ -290,7 +308,7 @@ func (alloc *allocator) rebase4Unsigned(requiredBase uint64, allocIDs bool) erro return nil } -func (alloc *allocator) rebase4Signed(requiredBase int64, allocIDs bool) error { +func (alloc *allocator) rebase4Signed(ctx context.Context, requiredBase int64, allocIDs bool) error { // Satisfied by alloc.base, nothing to do. 
if requiredBase <= alloc.base { return nil @@ -300,9 +318,22 @@ func (alloc *allocator) rebase4Signed(requiredBase int64, allocIDs bool) error { alloc.base = requiredBase return nil } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.rebaseCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } var newBase, newEnd int64 startTime := time.Now() - err := kv.RunInNewTxn(context.Background(), alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { + err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } idAcc := alloc.getIDAccessor(txn) currentEnd, err1 := idAcc.Get() if err1 != nil { @@ -379,13 +410,13 @@ func (alloc *allocator) rebase4Sequence(requiredBase int64) (int64, bool, error) // Rebase implements autoid.Allocator Rebase interface. // The requiredBase is the minimum base value after Rebase. // The real base may be greater than the required base. -func (alloc *allocator) Rebase(requiredBase int64, allocIDs bool) error { +func (alloc *allocator) Rebase(ctx context.Context, requiredBase int64, allocIDs bool) error { alloc.mu.Lock() defer alloc.mu.Unlock() if alloc.isUnsigned { - return alloc.rebase4Unsigned(uint64(requiredBase), allocIDs) + return alloc.rebase4Unsigned(ctx, uint64(requiredBase), allocIDs) } - return alloc.rebase4Signed(requiredBase, allocIDs) + return alloc.rebase4Signed(ctx, requiredBase, allocIDs) } // ForceRebase implements autoid.Allocator ForceRebase interface. @@ -695,7 +726,7 @@ func SeekToFirstAutoIDUnSigned(base, increment, offset uint64) uint64 { func (alloc *allocator) alloc4Signed(ctx context.Context, n uint64, increment, offset int64) (int64, int64, error) { // Check offset rebase if necessary. 
if offset-1 > alloc.base { - if err := alloc.rebase4Signed(offset-1, true); err != nil { + if err := alloc.rebase4Signed(ctx, offset-1, true); err != nil { return 0, 0, err } } @@ -716,12 +747,26 @@ func (alloc *allocator) alloc4Signed(ctx context.Context, n uint64, increment, o consumeDur := startTime.Sub(alloc.lastAllocTime) nextStep = NextStep(alloc.step, consumeDur) } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.allocCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } + err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("alloc.alloc4Signed", opentracing.ChildOf(span.Context())) defer span1.Finish() opentracing.ContextWithSpan(ctx, span1) } + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } idAcc := alloc.getIDAccessor(txn) var err1 error @@ -770,7 +815,7 @@ func (alloc *allocator) alloc4Signed(ctx context.Context, n uint64, increment, o func (alloc *allocator) alloc4Unsigned(ctx context.Context, n uint64, increment, offset int64) (int64, int64, error) { // Check offset rebase if necessary. 
if uint64(offset-1) > uint64(alloc.base) { - if err := alloc.rebase4Unsigned(uint64(offset-1), true); err != nil { + if err := alloc.rebase4Unsigned(ctx, uint64(offset-1), true); err != nil { return 0, 0, err } } @@ -791,12 +836,27 @@ func (alloc *allocator) alloc4Unsigned(ctx context.Context, n uint64, increment, consumeDur := startTime.Sub(alloc.lastAllocTime) nextStep = NextStep(alloc.step, consumeDur) } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.allocCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } + err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("alloc.alloc4Unsigned", opentracing.ChildOf(span.Context())) defer span1.Finish() opentracing.ContextWithSpan(ctx, span1) } + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } + idAcc := alloc.getIDAccessor(txn) var err1 error newBase, err1 = idAcc.Get() @@ -842,6 +902,17 @@ func (alloc *allocator) alloc4Unsigned(ctx context.Context, n uint64, increment, return min, alloc.base, nil } +func getAllocatorStatsFromCtx(ctx context.Context) (context.Context, *AllocatorRuntimeStats, **tikvutil.CommitDetails) { + var allocatorStats *AllocatorRuntimeStats + var commitDetail *tikvutil.CommitDetails + ctxValue := ctx.Value(AllocatorRuntimeStatsCtxKey) + if ctxValue != nil { + allocatorStats = ctxValue.(*AllocatorRuntimeStats) + ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) + } + return ctx, allocatorStats, &commitDetail +} + // alloc4Sequence is used to alloc value for sequence, there are several aspects different from autoid logic. // 1: sequence allocation don't need check rebase. // 2: sequence allocation don't need auto step. 
@@ -1024,3 +1095,111 @@ func (l *ShardIDLayout) IncrementalBitsCapacity() uint64 { func (l *ShardIDLayout) IncrementalMask() int64 { return (1 << l.IncrementalBits) - 1 } + +type allocatorRuntimeStatsCtxKeyType struct{} + +// AllocatorRuntimeStatsCtxKey is the context key of allocator runtime stats. +var AllocatorRuntimeStatsCtxKey = allocatorRuntimeStatsCtxKeyType{} + +// AllocatorRuntimeStats is the execution stats of auto id allocator. +type AllocatorRuntimeStats struct { + *txnsnapshot.SnapshotRuntimeStats + *execdetails.RuntimeStatsWithCommit + allocCount int + rebaseCount int +} + +// NewAllocatorRuntimeStats return a new AllocatorRuntimeStats. +func NewAllocatorRuntimeStats() *AllocatorRuntimeStats { + return &AllocatorRuntimeStats{ + SnapshotRuntimeStats: &txnsnapshot.SnapshotRuntimeStats{}, + } +} + +func (e *AllocatorRuntimeStats) mergeCommitDetail(detail *tikvutil.CommitDetails) { + if detail == nil { + return + } + if e.RuntimeStatsWithCommit == nil { + e.RuntimeStatsWithCommit = &execdetails.RuntimeStatsWithCommit{} + } + e.RuntimeStatsWithCommit.MergeCommitDetails(detail) +} + +// String implements the RuntimeStats interface. 
+func (e *AllocatorRuntimeStats) String() string { + if e.allocCount == 0 && e.rebaseCount == 0 { + return "" + } + var buf bytes.Buffer + buf.WriteString("auto_id_allocator: {") + initialSize := buf.Len() + if e.allocCount > 0 { + buf.WriteString("alloc_cnt: ") + buf.WriteString(strconv.FormatInt(int64(e.allocCount), 10)) + } + if e.rebaseCount > 0 { + if buf.Len() > initialSize { + buf.WriteString(", ") + } + buf.WriteString("rebase_cnt: ") + buf.WriteString(strconv.FormatInt(int64(e.rebaseCount), 10)) + } + if e.SnapshotRuntimeStats != nil { + stats := e.SnapshotRuntimeStats.String() + if stats != "" { + if buf.Len() > initialSize { + buf.WriteString(", ") + } + buf.WriteString(e.SnapshotRuntimeStats.String()) + } + } + if e.RuntimeStatsWithCommit != nil { + stats := e.RuntimeStatsWithCommit.String() + if stats != "" { + if buf.Len() > initialSize { + buf.WriteString(", ") + } + buf.WriteString(stats) + } + } + buf.WriteString("}") + return buf.String() +} + +// Clone implements the RuntimeStats interface. +func (e *AllocatorRuntimeStats) Clone() *AllocatorRuntimeStats { + newRs := &AllocatorRuntimeStats{ + allocCount: e.allocCount, + rebaseCount: e.rebaseCount, + } + if e.SnapshotRuntimeStats != nil { + snapshotStats := e.SnapshotRuntimeStats.Clone() + newRs.SnapshotRuntimeStats = snapshotStats + } + if e.RuntimeStatsWithCommit != nil { + newRs.RuntimeStatsWithCommit = e.RuntimeStatsWithCommit.Clone().(*execdetails.RuntimeStatsWithCommit) + } + return newRs +} + +// Merge implements the RuntimeStats interface. 
+func (e *AllocatorRuntimeStats) Merge(other *AllocatorRuntimeStats) { + if other == nil { + return + } + if other.SnapshotRuntimeStats != nil { + if e.SnapshotRuntimeStats == nil { + e.SnapshotRuntimeStats = other.SnapshotRuntimeStats.Clone() + } else { + e.SnapshotRuntimeStats.Merge(other.SnapshotRuntimeStats) + } + } + if other.RuntimeStatsWithCommit != nil { + if e.RuntimeStatsWithCommit == nil { + e.RuntimeStatsWithCommit = other.RuntimeStatsWithCommit.Clone().(*execdetails.RuntimeStatsWithCommit) + } else { + e.RuntimeStatsWithCommit.Merge(other.RuntimeStatsWithCommit) + } + } +} diff --git a/meta/autoid/autoid_test.go b/meta/autoid/autoid_test.go index cc1e8459ddfff..1fd0a1c07d1b3 100644 --- a/meta/autoid/autoid_test.go +++ b/meta/autoid/autoid_test.go @@ -83,22 +83,22 @@ func TestSignedAutoid(t *testing.T) { require.Equal(t, autoid.GetStep()+1, globalAutoID) // rebase - err = alloc.Rebase(int64(1), true) + err = alloc.Rebase(context.Background(), int64(1), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(3), id) - err = alloc.Rebase(int64(3), true) + err = alloc.Rebase(context.Background(), int64(3), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(4), id) - err = alloc.Rebase(int64(10), true) + err = alloc.Rebase(context.Background(), int64(10), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(11), id) - err = alloc.Rebase(int64(3010), true) + err = alloc.Rebase(context.Background(), int64(3010), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) @@ -112,7 +112,7 @@ func TestSignedAutoid(t *testing.T) { alloc = autoid.NewAllocator(store, 1, 2, false, autoid.RowIDAllocType) require.NotNil(t, alloc) - err = alloc.Rebase(int64(1), false) + err = alloc.Rebase(context.Background(), int64(1), false) require.NoError(t, err) 
_, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) @@ -120,27 +120,27 @@ func TestSignedAutoid(t *testing.T) { alloc = autoid.NewAllocator(store, 1, 3, false, autoid.RowIDAllocType) require.NotNil(t, alloc) - err = alloc.Rebase(int64(3210), false) + err = alloc.Rebase(context.Background(), int64(3210), false) require.NoError(t, err) alloc = autoid.NewAllocator(store, 1, 3, false, autoid.RowIDAllocType) require.NotNil(t, alloc) - err = alloc.Rebase(int64(3000), false) + err = alloc.Rebase(context.Background(), int64(3000), false) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(3211), id) - err = alloc.Rebase(int64(6543), false) + err = alloc.Rebase(context.Background(), int64(6543), false) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(6544), id) // Test the MaxInt64 is the upper bound of `alloc` function but not `rebase`. - err = alloc.Rebase(int64(math.MaxInt64-1), true) + err = alloc.Rebase(context.Background(), int64(math.MaxInt64-1), true) require.NoError(t, err) _, _, err = alloc.Alloc(ctx, 1, 1, 1) require.Error(t, err) - err = alloc.Rebase(int64(math.MaxInt64), true) + err = alloc.Rebase(context.Background(), int64(math.MaxInt64), true) require.NoError(t, err) // alloc N for signed @@ -169,7 +169,7 @@ func TestSignedAutoid(t *testing.T) { expected++ } - err = alloc.Rebase(int64(1000), false) + err = alloc.Rebase(context.Background(), int64(1000), false) require.NoError(t, err) min, max, err = alloc.Alloc(ctx, 3, 1, 1) require.NoError(t, err) @@ -179,7 +179,7 @@ func TestSignedAutoid(t *testing.T) { require.Equal(t, int64(1003), max) lastRemainOne := alloc.End() - err = alloc.Rebase(alloc.End()-2, false) + err = alloc.Rebase(context.Background(), alloc.End()-2, false) require.NoError(t, err) min, max, err = alloc.Alloc(ctx, 5, 1, 1) require.NoError(t, err) @@ -287,22 +287,22 @@ func TestUnsignedAutoid(t *testing.T) { 
require.Equal(t, autoid.GetStep()+1, globalAutoID) // rebase - err = alloc.Rebase(int64(1), true) + err = alloc.Rebase(context.Background(), int64(1), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(3), id) - err = alloc.Rebase(int64(3), true) + err = alloc.Rebase(context.Background(), int64(3), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(4), id) - err = alloc.Rebase(int64(10), true) + err = alloc.Rebase(context.Background(), int64(10), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(11), id) - err = alloc.Rebase(int64(3010), true) + err = alloc.Rebase(context.Background(), int64(3010), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) @@ -316,7 +316,7 @@ func TestUnsignedAutoid(t *testing.T) { alloc = autoid.NewAllocator(store, 1, 2, true, autoid.RowIDAllocType) require.NotNil(t, alloc) - err = alloc.Rebase(int64(1), false) + err = alloc.Rebase(context.Background(), int64(1), false) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) @@ -324,16 +324,16 @@ func TestUnsignedAutoid(t *testing.T) { alloc = autoid.NewAllocator(store, 1, 3, true, autoid.RowIDAllocType) require.NotNil(t, alloc) - err = alloc.Rebase(int64(3210), false) + err = alloc.Rebase(context.Background(), int64(3210), false) require.NoError(t, err) alloc = autoid.NewAllocator(store, 1, 3, true, autoid.RowIDAllocType) require.NotNil(t, alloc) - err = alloc.Rebase(int64(3000), false) + err = alloc.Rebase(context.Background(), int64(3000), false) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(3211), id) - err = alloc.Rebase(int64(6543), false) + err = alloc.Rebase(context.Background(), int64(6543), false) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 
1, 1, 1) require.NoError(t, err) @@ -342,12 +342,12 @@ func TestUnsignedAutoid(t *testing.T) { // Test the MaxUint64 is the upper bound of `alloc` func but not `rebase`. var n uint64 = math.MaxUint64 - 1 un := int64(n) - err = alloc.Rebase(un, true) + err = alloc.Rebase(context.Background(), un, true) require.NoError(t, err) _, _, err = alloc.Alloc(ctx, 1, 1, 1) require.Error(t, err) un = int64(n + 1) - err = alloc.Rebase(un, true) + err = alloc.Rebase(context.Background(), un, true) require.NoError(t, err) // alloc N for unsigned @@ -363,7 +363,7 @@ func TestUnsignedAutoid(t *testing.T) { require.Equal(t, int64(1), min+1) require.Equal(t, int64(2), max) - err = alloc.Rebase(int64(500), true) + err = alloc.Rebase(context.Background(), int64(500), true) require.NoError(t, err) min, max, err = alloc.Alloc(ctx, 2, 1, 1) require.NoError(t, err) @@ -372,7 +372,7 @@ func TestUnsignedAutoid(t *testing.T) { require.Equal(t, int64(502), max) lastRemainOne := alloc.End() - err = alloc.Rebase(alloc.End()-2, false) + err = alloc.Rebase(context.Background(), alloc.End()-2, false) require.NoError(t, err) min, max, err = alloc.Alloc(ctx, 5, 1, 1) require.NoError(t, err) @@ -521,7 +521,7 @@ func TestRollbackAlloc(t *testing.T) { require.Equal(t, int64(0), alloc.Base()) require.Equal(t, int64(0), alloc.End()) - err = alloc.Rebase(100, true) + err = alloc.Rebase(context.Background(), 100, true) require.Error(t, err) require.Equal(t, int64(0), alloc.Base()) require.Equal(t, int64(0), alloc.End()) @@ -573,10 +573,10 @@ func TestAllocComputationIssue(t *testing.T) { require.NotNil(t, signedAlloc2) // the next valid two value must be 13 & 16, batch size = 6. - err = unsignedAlloc1.Rebase(10, false) + err = unsignedAlloc1.Rebase(context.Background(), 10, false) require.NoError(t, err) // the next valid two value must be 10 & 13, batch size = 6. 
- err = signedAlloc2.Rebase(7, false) + err = signedAlloc2.Rebase(context.Background(), 7, false) require.NoError(t, err) // Simulate the rest cache is not enough for next batch, assuming 10 & 13, batch size = 4. autoid.TestModifyBaseAndEndInjection(unsignedAlloc1, 9, 9) diff --git a/meta/autoid/memid.go b/meta/autoid/memid.go index 1a9af524959b2..21848c9d455d6 100644 --- a/meta/autoid/memid.go +++ b/meta/autoid/memid.go @@ -86,7 +86,7 @@ func (alloc *inMemoryAllocator) Alloc(ctx context.Context, n uint64, increment, // Rebase implements autoid.Allocator Rebase interface. // The requiredBase is the minimum base value after Rebase. // The real base may be greater than the required base. -func (alloc *inMemoryAllocator) Rebase(requiredBase int64, allocIDs bool) error { +func (alloc *inMemoryAllocator) Rebase(ctx context.Context, requiredBase int64, allocIDs bool) error { if alloc.isUnsigned { if uint64(requiredBase) > uint64(alloc.base) { alloc.base = requiredBase diff --git a/meta/autoid/memid_test.go b/meta/autoid/memid_test.go index 15dd436f98ed1..46f67170b0673 100644 --- a/meta/autoid/memid_test.go +++ b/meta/autoid/memid_test.go @@ -72,19 +72,19 @@ func TestInMemoryAlloc(t *testing.T) { require.Equal(t, int64(30), id) // rebase - err = alloc.Rebase(int64(40), true) + err = alloc.Rebase(context.Background(), int64(40), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(41), id) - err = alloc.Rebase(int64(10), true) + err = alloc.Rebase(context.Background(), int64(10), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) require.Equal(t, int64(42), id) // maxInt64 - err = alloc.Rebase(int64(math.MaxInt64-2), true) + err = alloc.Rebase(context.Background(), int64(math.MaxInt64-2), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) @@ -98,7 +98,7 @@ func TestInMemoryAlloc(t *testing.T) { require.NotNil(t, alloc) var n 
uint64 = math.MaxUint64 - 2 - err = alloc.Rebase(int64(n), true) + err = alloc.Rebase(context.Background(), int64(n), true) require.NoError(t, err) _, id, err = alloc.Alloc(ctx, 1, 1, 1) require.NoError(t, err) diff --git a/planner/core/cache_test.go b/planner/core/cache_test.go index 3fb76763bd9ac..457b3e03ff1ca 100644 --- a/planner/core/cache_test.go +++ b/planner/core/cache_test.go @@ -15,31 +15,20 @@ package core import ( + "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/util/testleak" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testCacheSuite{}) - -type testCacheSuite struct { - ctx sessionctx.Context -} - -func (s *testCacheSuite) SetUpSuite(c *C) { +func TestCacheKey(t *testing.T) { + t.Parallel() ctx := MockContext() ctx.GetSessionVars().SnapshotTS = 0 ctx.GetSessionVars().SQLMode = mysql.ModeNone ctx.GetSessionVars().TimeZone = time.UTC ctx.GetSessionVars().ConnectionID = 0 - s.ctx = ctx -} - -func (s *testCacheSuite) TestCacheKey(c *C) { - defer testleak.AfterTest(c)() - key := NewPSTMTPlanCacheKey(s.ctx.GetSessionVars(), 1, 1) - c.Assert(key.Hash(), DeepEquals, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + key := NewPSTMTPlanCacheKey(ctx.GetSessionVars(), 1, 1) + require.Equal(t, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68, 0x7f, 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, 0xff}, key.Hash()) } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 753215b11a9a7..5b55baa515bf4 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -2489,6 +2489,106 @@ func (s *testIntegrationSerialSuite) TestExplainAnalyzeDML(c *C) { checkExplain("BatchGet") } +func (s *testIntegrationSerialSuite) TestExplainAnalyzeDML2(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + cases := []struct { + prepare string + sql string + planRegexp string + }{ + // Test for alloc auto ID. + { + sql: "insert into t () values ()", + planRegexp: ".*prepare.*total.*, auto_id_allocator.*alloc_cnt: 1, Get.*num_rpc.*total_time.*commit_txn.*prewrite.*get_commit_ts.*commit.*write_keys.*, insert.*", + }, + // Test for rebase ID. + { + sql: "insert into t (a) values (99000000000)", + planRegexp: ".*prepare.*total.*, auto_id_allocator.*rebase_cnt: 1, Get.*num_rpc.*total_time.*commit_txn.*prewrite.*get_commit_ts.*commit.*write_keys.*, insert.*", + }, + // Test for alloc auto ID and rebase ID. + { + sql: "insert into t (a) values (null), (99000000000)", + planRegexp: ".*prepare.*total.*, auto_id_allocator.*alloc_cnt: 1, rebase_cnt: 1, Get.*num_rpc.*total_time.*commit_txn.*prewrite.*get_commit_ts.*commit.*write_keys.*, insert.*", + }, + // Test for insert ignore. + { + sql: "insert ignore into t values (null,1), (2, 2), (99000000000, 3), (100000000000, 4)", + planRegexp: ".*prepare.*total.*, auto_id_allocator.*alloc_cnt: 1, rebase_cnt: 2, Get.*num_rpc.*total_time.*commit_txn.*count: 3, prewrite.*get_commit_ts.*commit.*write_keys.*, check_insert.*", + }, + // Test for insert on duplicate. 
+ { + sql: "insert into t values (null,null), (1,1),(2,2) on duplicate key update a = a + 100000000000", + planRegexp: ".*prepare.*total.*, auto_id_allocator.*alloc_cnt: 1, rebase_cnt: 1, Get.*num_rpc.*total_time.*commit_txn.*count: 2, prewrite.*get_commit_ts.*commit.*write_keys.*, check_insert.*", + }, + // Test for replace with alloc ID. + { + sql: "replace into t () values ()", + planRegexp: ".*auto_id_allocator.*alloc_cnt: 1, Get.*num_rpc.*total_time.*commit_txn.*prewrite.*get_commit_ts.*commit.*write_keys.*", + }, + // Test for replace with alloc ID and rebase ID. + { + sql: "replace into t (a) values (null), (99000000000)", + planRegexp: ".*auto_id_allocator.*alloc_cnt: 1, rebase_cnt: 1, Get.*num_rpc.*total_time.*commit_txn.*prewrite.*get_commit_ts.*commit.*write_keys.*", + }, + // Test for update with rebase ID. + { + prepare: "insert into t values (1,1),(2,2)", + sql: "update t set a=a*100000000000", + planRegexp: ".*auto_id_allocator.*rebase_cnt: 2, Get.*num_rpc.*total_time.*commit_txn.*prewrite.*get_commit_ts.*commit.*write_keys.*", + }, + } + + for _, ca := range cases { + for i := 0; i < 3; i++ { + tk.MustExec("drop table if exists t") + switch i { + case 0: + tk.MustExec("create table t (a bigint auto_increment, b int, primary key (a));") + case 1: + tk.MustExec("create table t (a bigint unsigned auto_increment, b int, primary key (a));") + case 2: + if strings.Contains(ca.sql, "on duplicate key") { + continue + } + tk.MustExec("create table t (a bigint primary key auto_random(5), b int);") + tk.MustExec("set @@allow_auto_random_explicit_insert=1;") + default: + panic("should never happen") + } + if ca.prepare != "" { + tk.MustExec(ca.prepare) + } + res := tk.MustQuery("explain analyze " + ca.sql) + resBuff := bytes.NewBufferString("") + for _, row := range res.Rows() { + fmt.Fprintf(resBuff, "%s\t", row) + } + explain := resBuff.String() + c.Assert(explain, Matches, ca.planRegexp, Commentf("idx: %v,sql: %v", i, ca.sql)) + } + } + + // Test for table 
without auto id. + for _, ca := range cases { + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a bigint, b int);") + tk.MustExec("insert into t () values ()") + if ca.prepare != "" { + tk.MustExec(ca.prepare) + } + res := tk.MustQuery("explain analyze " + ca.sql) + resBuff := bytes.NewBufferString("") + for _, row := range res.Rows() { + fmt.Fprintf(resBuff, "%s\t", row) + } + explain := resBuff.String() + c.Assert(strings.Contains(explain, "auto_id_allocator"), IsFalse, Commentf("sql: %v, explain: %v", ca.sql, explain)) + } +} + func (s *testIntegrationSuite) TestPartitionExplain(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -3237,7 +3337,7 @@ func (s *testIntegrationSerialSuite) TestPushDownProjectionForTiFlash(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.MustExec("create table t (id int, value decimal(6,3))") + tk.MustExec("create table t (id int, value decimal(6,3), name char(128))") tk.MustExec("analyze table t") tk.MustExec("set session tidb_allow_mpp=OFF") @@ -3277,7 +3377,7 @@ func (s *testIntegrationSerialSuite) TestPushDownProjectionForMPP(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.MustExec("create table t (id int, value decimal(6,3))") + tk.MustExec("create table t (id int, value decimal(6,3), name char(128))") tk.MustExec("analyze table t") // Create virtual tiflash replica info. 
diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 2ee4991f28609..eaa731ba04fdd 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1072,6 +1072,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { {mysql.GrantPriv, "test", "", "", nil, false, "", false}, {mysql.ReferencesPriv, "test", "", "", nil, false, "", false}, {mysql.LockTablesPriv, "test", "", "", nil, false, "", false}, + {mysql.CreateTMPTablePriv, "test", "", "", nil, false, "", false}, {mysql.AlterPriv, "test", "", "", nil, false, "", false}, {mysql.ExecutePriv, "test", "", "", nil, false, "", false}, {mysql.IndexPriv, "test", "", "", nil, false, "", false}, @@ -1142,6 +1143,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { {mysql.GrantPriv, "test", "", "", nil, false, "", false}, {mysql.ReferencesPriv, "test", "", "", nil, false, "", false}, {mysql.LockTablesPriv, "test", "", "", nil, false, "", false}, + {mysql.CreateTMPTablePriv, "test", "", "", nil, false, "", false}, {mysql.AlterPriv, "test", "", "", nil, false, "", false}, {mysql.ExecutePriv, "test", "", "", nil, false, "", false}, {mysql.IndexPriv, "test", "", "", nil, false, "", false}, diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 001add7c5021b..b372b6c92f946 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -199,9 +199,40 @@ func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan { mergeContinuousSelections(plan) plan = eliminateUnionScanAndLock(sctx, plan) plan = enableParallelApply(sctx, plan) + checkPlanCacheable(sctx, plan) return plan } +// checkPlanCacheable used to check whether a plan can be cached. Plans that +// meet the following characteristics cannot be cached: +// 1. Use the TiFlash engine. +// Todo: make more careful check here. 
+func checkPlanCacheable(sctx sessionctx.Context, plan PhysicalPlan) { + if sctx.GetSessionVars().StmtCtx.UseCache && useTiFlash(plan) { + sctx.GetSessionVars().StmtCtx.MaybeOverOptimized4PlanCache = true + } +} + +// useTiFlash used to check whether the plan use the TiFlash engine. +func useTiFlash(p PhysicalPlan) bool { + switch x := p.(type) { + case *PhysicalTableReader: + switch x.StoreType { + case kv.TiFlash: + return true + default: + return false + } + default: + if len(p.Children()) > 0 { + for _, plan := range p.Children() { + return useTiFlash(plan) + } + } + } + return false +} + func enableParallelApply(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan { if !sctx.GetSessionVars().EnableParallelApply { return plan diff --git a/planner/core/plan_to_pb_test.go b/planner/core/plan_to_pb_test.go index b61c9d0b73f23..3cbe142681c52 100644 --- a/planner/core/plan_to_pb_test.go +++ b/planner/core/plan_to_pb_test.go @@ -15,22 +15,19 @@ package core import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/collate" - "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tipb/go-tipb" + "github.com/stretchr/testify/require" ) -var _ = SerialSuites(&testDistsqlSuite{}) - -type testDistsqlSuite struct{} - -func (s *testDistsqlSuite) TestColumnToProto(c *C) { - defer testleak.AfterTest(c)() +func TestColumnToProto(t *testing.T) { + t.Parallel() // Make sure the Flag is set in tipb.ColumnInfo tp := types.NewFieldType(mysql.TypeLong) tp.Flag = 10 @@ -40,16 +37,16 @@ func (s *testDistsqlSuite) TestColumnToProto(c *C) { } pc := util.ColumnToProto(col) expect := &tipb.ColumnInfo{ColumnId: 0, Tp: 3, Collation: 83, ColumnLen: -1, Decimal: -1, Flag: 10, Elems: []string(nil), DefaultVal: []uint8(nil), PkHandle: false, XXX_unrecognized: []uint8(nil)} - c.Assert(pc, DeepEquals, expect) + require.Equal(t, 
expect, pc) cols := []*model.ColumnInfo{col, col} pcs := util.ColumnsToProto(cols, false) for _, v := range pcs { - c.Assert(v.GetFlag(), Equals, int32(10)) + require.Equal(t, int32(10), v.GetFlag()) } pcs = util.ColumnsToProto(cols, true) for _, v := range pcs { - c.Assert(v.GetFlag(), Equals, int32(10)) + require.Equal(t, int32(10), v.GetFlag()) } // Make sure the collation ID is successfully set. @@ -60,20 +57,20 @@ func (s *testDistsqlSuite) TestColumnToProto(c *C) { FieldType: *tp, } pc = util.ColumnToProto(col1) - c.Assert(pc.Collation, Equals, int32(8)) + require.Equal(t, int32(8), pc.Collation) collate.SetNewCollationEnabledForTest(true) defer collate.SetNewCollationEnabledForTest(false) pc = util.ColumnToProto(col) expect = &tipb.ColumnInfo{ColumnId: 0, Tp: 3, Collation: -83, ColumnLen: -1, Decimal: -1, Flag: 10, Elems: []string(nil), DefaultVal: []uint8(nil), PkHandle: false, XXX_unrecognized: []uint8(nil)} - c.Assert(pc, DeepEquals, expect) + require.Equal(t, expect, pc) pcs = util.ColumnsToProto(cols, true) for _, v := range pcs { - c.Assert(v.Collation, Equals, int32(-83)) + require.Equal(t, int32(-83), v.Collation) } pc = util.ColumnToProto(col1) - c.Assert(pc.Collation, Equals, int32(-8)) + require.Equal(t, int32(-8), pc.Collation) tp = types.NewFieldType(mysql.TypeEnum) tp.Flag = 10 @@ -82,5 +79,5 @@ func (s *testDistsqlSuite) TestColumnToProto(c *C) { FieldType: *tp, } pc = util.ColumnToProto(col2) - c.Assert(len(pc.Elems), Equals, 2) + require.Len(t, pc.Elems, 2) } diff --git a/planner/core/testdata/enforce_mpp_suite_out.json b/planner/core/testdata/enforce_mpp_suite_out.json index 97cde0ecbde67..428be48f548d6 100644 --- a/planner/core/testdata/enforce_mpp_suite_out.json +++ b/planner/core/testdata/enforce_mpp_suite_out.json @@ -325,11 +325,11 @@ " └─TableFullScan_10 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": [ - "Scalar function 'md5'(signature: MD5) can not be pushed to tiflash", + "Scalar function 'md5'(signature: 
MD5, return type: var_string(32)) can not be pushed to tiflash", "Aggregation can not be pushed to tiflash because groupByItems contain unsupported exprs", - "Scalar function 'md5'(signature: MD5) can not be pushed to tiflash", + "Scalar function 'md5'(signature: MD5, return type: var_string(32)) can not be pushed to tiflash", "Aggregation can not be pushed to tiflash because groupByItems contain unsupported exprs", - "Scalar function 'md5'(signature: MD5) can not be pushed to tiflash", + "Scalar function 'md5'(signature: MD5, return type: var_string(32)) can not be pushed to tiflash", "Aggregation can not be pushed to tiflash because groupByItems contain unsupported exprs" ] }, diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index 13661972b9d4f..cae3db289d566 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -3,7 +3,7 @@ "name": "TestSelPushDownTiFlash", "cases": [ "explain format = 'brief' select * from t where t.a > 1 and t.b = \"flash\" or t.a + 3 * t.a = 5", - "explain format = 'brief' select * from t where cast(t.a as float) + 3 = 5.1", + "explain format = 'brief' select * from t where cast(t.a as double) + 3 = 5.1", "explain format = 'brief' select * from t where b > 'a' order by convert(b, unsigned) limit 2", "explain format = 'brief' select * from t where b > 'a' order by b limit 2" ] @@ -219,7 +219,8 @@ "desc format = 'brief' select * from t right join (select id-2 as b from t) A on A.b=t.id", "desc format = 'brief' select A.b, B.b from (select id-2 as b from t) B join (select id-2 as b from t) A on A.b=B.b", "desc format = 'brief' select A.id from t as A where exists (select 1 from t where t.id=A.id)", - "desc format = 'brief' select A.id from t as A where not exists (select 1 from t where t.id=A.id)" + "desc format = 'brief' select A.id from t as A where not exists (select 1 from t where 
t.id=A.id)", + "desc format = 'brief' SELECT FROM_UNIXTIME(name,'%Y-%m-%d') FROM t;" ] }, { @@ -238,7 +239,8 @@ "desc format = 'brief' select A.b, B.b from (select id-2 as b from t) B join (select id-2 as b from t) A on A.b=B.b", "desc format = 'brief' select id from t as A where exists (select 1 from t where t.id=A.id)", "desc format = 'brief' select id from t as A where not exists (select 1 from t where t.id=A.id)", - "desc format = 'brief' select b*2, id from (select avg(value+2) as b, id from t group by id) C order by id" + "desc format = 'brief' select b*2, id from (select avg(value+2) as b, id from t group by id) C order by id", + "desc format = 'brief' SELECT FROM_UNIXTIME(name,'%Y-%m-%d') FROM t;" ] }, { diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index a5ec475bb9005..20d3a375f1024 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -11,10 +11,10 @@ ] }, { - "SQL": "explain format = 'brief' select * from t where cast(t.a as float) + 3 = 5.1", + "SQL": "explain format = 'brief' select * from t where cast(t.a as double) + 3 = 5.1", "Plan": [ "TableReader 8000.00 root data:Selection", - "└─Selection 8000.00 cop[tiflash] eq(plus(cast(test.t.a, float BINARY), 3), 5.1)", + "└─Selection 8000.00 cop[tiflash] eq(plus(cast(test.t.a, double BINARY), 3), 5.1)", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] }, @@ -1578,58 +1578,58 @@ { "SQL": "desc format = 'brief' select /*+ hash_agg()*/ count(b) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:count(Column#7)->Column#5", + "HashAgg 1.00 root funcs:count(Column#8)->Column#6", "└─TableReader 1.00 root data:HashAgg", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7", - " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─HashAgg 1.00 batchCop[tiflash] 
funcs:count(Column#10)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#10", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ hash_agg()*/ count(*) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:count(Column#6)->Column#5", + "HashAgg 1.00 root funcs:count(Column#7)->Column#6", "└─TableReader 1.00 root data:HashAgg", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#6", + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#7", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ hash_agg()*/ sum(b) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:sum(Column#7)->Column#5", + "HashAgg 1.00 root funcs:sum(Column#8)->Column#6", "└─TableReader 1.00 root data:HashAgg", - " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#7", - " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#10)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#10", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ stream_agg()*/ count(b) from (select id + 1 as b from t)A", "Plan": [ - "StreamAgg 1.00 root funcs:count(Column#7)->Column#5", + "StreamAgg 1.00 root funcs:count(Column#8)->Column#6", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7", - " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#10)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#10", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep 
order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ stream_agg()*/ count(*) from (select id + 1 as b from t)A", "Plan": [ - "StreamAgg 1.00 root funcs:count(Column#6)->Column#5", + "StreamAgg 1.00 root funcs:count(Column#7)->Column#6", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#6", + " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#7", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ stream_agg()*/ sum(b) from (select id + 1 as b from t)A", "Plan": [ - "StreamAgg 1.00 root funcs:sum(Column#7)->Column#5", + "StreamAgg 1.00 root funcs:sum(Column#8)->Column#6", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#7", - " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#10)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#10", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, @@ -1637,11 +1637,11 @@ "SQL": "desc format = 'brief' select * from (select id-2 as b from t) B join (select id-2 as b from t) A on A.b=B.b", "Plan": [ "TableReader 10000.00 root data:HashJoin", - "└─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#4, Column#8)]", - " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#4", + "└─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#5, Column#10)]", + " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#5", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo, global read", - " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#8", + " 
└─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#10", " └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] @@ -1650,8 +1650,8 @@ "SQL": "desc format = 'brief' select * from t join (select id-2 as b from t) A on A.b=t.id", "Plan": [ "TableReader 10000.00 root data:HashJoin", - "└─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(test.t.id, Column#7)]", - " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#7", + "└─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(test.t.id, Column#9)]", + " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#9", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo, global read", " └─Selection(Probe) 9990.00 cop[tiflash] not(isnull(test.t.id))", @@ -1662,8 +1662,8 @@ "SQL": "desc format = 'brief' select * from t left join (select id-2 as b from t) A on A.b=t.id", "Plan": [ "TableReader 10000.00 root data:HashJoin", - "└─HashJoin 10000.00 cop[tiflash] left outer join, equal:[eq(test.t.id, Column#7)]", - " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#7", + "└─HashJoin 10000.00 cop[tiflash] left outer join, equal:[eq(test.t.id, Column#9)]", + " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#9", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo, global read", " └─TableFullScan(Probe) 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" @@ -1673,23 +1673,23 @@ "SQL": "desc format = 'brief' select * from t right join (select id-2 as b from t) A on A.b=t.id", "Plan": [ "TableReader 12487.50 root data:HashJoin", - "└─HashJoin 12487.50 cop[tiflash] right outer join, equal:[eq(test.t.id, Column#7)]", + "└─HashJoin 12487.50 
cop[tiflash] right outer join, equal:[eq(test.t.id, Column#9)]", " ├─Selection(Build) 9990.00 cop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo, global read", - " └─Projection(Probe) 10000.00 cop[tiflash] minus(test.t.id, 2)->Column#7", + " └─Projection(Probe) 10000.00 cop[tiflash] minus(test.t.id, 2)->Column#9", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select A.b, B.b from (select id-2 as b from t) B join (select id-2 as b from t) A on A.b=B.b", "Plan": [ - "Projection 10000.00 root Column#8, Column#4", + "Projection 10000.00 root Column#10, Column#5", "└─TableReader 10000.00 root data:HashJoin", - " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#4, Column#8)]", - " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#4", + " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#5, Column#10)]", + " ├─Projection(Build) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#5", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo, global read", - " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#8", + " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#10", " └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] @@ -1713,6 +1713,14 @@ " ├─TableFullScan(Build) 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo, global read", " └─TableFullScan(Probe) 10000.00 cop[tiflash] table:A keep order:false, stats:pseudo" ] + }, + { + "SQL": "desc format = 'brief' SELECT FROM_UNIXTIME(name,'%Y-%m-%d') FROM t;", + "Plan": [ + "Projection 10000.00 root from_unixtime(cast(test.t.name, decimal(65,0) BINARY), %Y-%m-%d)->Column#5", + "└─TableReader 10000.00 root 
data:TableFullScan", + " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + ] } ] }, @@ -1722,64 +1730,64 @@ { "SQL": "desc format = 'brief' select /*+ hash_agg()*/ count(b) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:count(Column#8)->Column#5", + "HashAgg 1.00 root funcs:count(Column#9)->Column#6", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#10)->Column#8", - " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#10", + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#11)->Column#9", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#11", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ hash_agg()*/ count(*) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:count(Column#7)->Column#5", + "HashAgg 1.00 root funcs:count(Column#8)->Column#6", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#7", + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#8", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ hash_agg()*/ sum(b) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:sum(Column#8)->Column#5", + "HashAgg 1.00 root funcs:sum(Column#9)->Column#6", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#10)->Column#8", - " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#10", + " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#11)->Column#9", + " 
└─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#11", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ stream_agg()*/ count(b) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:count(Column#9)->Column#5", + "HashAgg 1.00 root funcs:count(Column#10)->Column#6", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#10)->Column#9", - " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#10", + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#11)->Column#10", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#11", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ stream_agg()*/ count(*) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:count(Column#8)->Column#5", + "HashAgg 1.00 root funcs:count(Column#9)->Column#6", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#8", + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#9", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { "SQL": "desc format = 'brief' select /*+ stream_agg()*/ sum(b) from (select id + 1 as b from t)A", "Plan": [ - "HashAgg 1.00 root funcs:sum(Column#9)->Column#5", + "HashAgg 1.00 root funcs:sum(Column#10)->Column#6", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#10)->Column#9", - " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#10", + " └─HashAgg 
1.00 batchCop[tiflash] funcs:sum(Column#11)->Column#10", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#11", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, @@ -1788,14 +1796,14 @@ "Plan": [ "TableReader 10000.00 root data:ExchangeSender", "└─ExchangeSender 10000.00 cop[tiflash] ExchangeType: PassThrough", - " └─Projection 10000.00 cop[tiflash] plus(Column#4, Column#8)->Column#9", - " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#4, Column#8)]", + " └─Projection 10000.00 cop[tiflash] plus(Column#5, Column#10)->Column#11", + " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#5, Column#10)]", " ├─ExchangeReceiver(Build) 8000.00 cop[tiflash] ", " │ └─ExchangeSender 8000.00 cop[tiflash] ExchangeType: Broadcast", - " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#4", + " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#5", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", - " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#8", + " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#10", " └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] @@ -1805,10 +1813,10 @@ "Plan": [ "TableReader 10000.00 root data:ExchangeSender", "└─ExchangeSender 10000.00 cop[tiflash] ExchangeType: PassThrough", - " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(test.t.id, Column#7)]", + " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(test.t.id, Column#9)]", " ├─ExchangeReceiver(Build) 8000.00 cop[tiflash] ", " │ └─ExchangeSender 8000.00 cop[tiflash] ExchangeType: Broadcast", - " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#7", + " │ └─Projection 8000.00 cop[tiflash] 
minus(test.t.id, 2)->Column#9", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", " └─Selection(Probe) 9990.00 cop[tiflash] not(isnull(test.t.id))", @@ -1820,10 +1828,10 @@ "Plan": [ "TableReader 10000.00 root data:ExchangeSender", "└─ExchangeSender 10000.00 cop[tiflash] ExchangeType: PassThrough", - " └─HashJoin 10000.00 cop[tiflash] left outer join, equal:[eq(test.t.id, Column#7)]", + " └─HashJoin 10000.00 cop[tiflash] left outer join, equal:[eq(test.t.id, Column#9)]", " ├─ExchangeReceiver(Build) 8000.00 cop[tiflash] ", " │ └─ExchangeSender 8000.00 cop[tiflash] ExchangeType: Broadcast", - " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#7", + " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#9", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", " └─TableFullScan(Probe) 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" @@ -1834,12 +1842,12 @@ "Plan": [ "TableReader 12487.50 root data:ExchangeSender", "└─ExchangeSender 12487.50 cop[tiflash] ExchangeType: PassThrough", - " └─HashJoin 12487.50 cop[tiflash] right outer join, equal:[eq(test.t.id, Column#7)]", + " └─HashJoin 12487.50 cop[tiflash] right outer join, equal:[eq(test.t.id, Column#9)]", " ├─ExchangeReceiver(Build) 9990.00 cop[tiflash] ", " │ └─ExchangeSender 9990.00 cop[tiflash] ExchangeType: Broadcast", " │ └─Selection 9990.00 cop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", - " └─Projection(Probe) 10000.00 cop[tiflash] minus(test.t.id, 2)->Column#7", + " └─Projection(Probe) 10000.00 cop[tiflash] minus(test.t.id, 2)->Column#9", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] }, @@ -1848,14 +1856,14 @@ "Plan": [ "TableReader 10000.00 root 
data:ExchangeSender", "└─ExchangeSender 10000.00 cop[tiflash] ExchangeType: PassThrough", - " └─Projection 10000.00 cop[tiflash] Column#8, Column#4", - " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#4, Column#8)]", + " └─Projection 10000.00 cop[tiflash] Column#10, Column#5", + " └─HashJoin 10000.00 cop[tiflash] inner join, equal:[eq(Column#5, Column#10)]", " ├─ExchangeReceiver(Build) 8000.00 cop[tiflash] ", " │ └─ExchangeSender 8000.00 cop[tiflash] ExchangeType: Broadcast", - " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#4", + " │ └─Projection 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#5", " │ └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", - " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#8", + " └─Projection(Probe) 8000.00 cop[tiflash] minus(test.t.id, 2)->Column#10", " └─Selection 8000.00 cop[tiflash] not(isnull(minus(test.t.id, 2)))", " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" ] @@ -1892,15 +1900,24 @@ "Sort 8000.00 root test.t.id", "└─TableReader 8000.00 root data:ExchangeSender", " └─ExchangeSender 8000.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─Projection 8000.00 batchCop[tiflash] mul(Column#4, 2)->Column#5, test.t.id", - " └─Projection 8000.00 batchCop[tiflash] div(Column#4, cast(case(eq(Column#19, 0), 1, Column#19), decimal(20,0) BINARY))->Column#4, test.t.id", - " └─HashAgg 8000.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#20)->Column#19, funcs:sum(Column#21)->Column#4, funcs:firstrow(test.t.id)->test.t.id", + " └─Projection 8000.00 batchCop[tiflash] mul(Column#5, 2)->Column#6, test.t.id", + " └─Projection 8000.00 batchCop[tiflash] div(Column#5, cast(case(eq(Column#20, 0), 1, Column#20), decimal(20,0) BINARY))->Column#5, test.t.id", + " └─HashAgg 8000.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#21)->Column#20, 
funcs:sum(Column#22)->Column#5, funcs:firstrow(test.t.id)->test.t.id", " └─ExchangeReceiver 8000.00 batchCop[tiflash] ", " └─ExchangeSender 8000.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.id, collate: N/A]", - " └─HashAgg 8000.00 batchCop[tiflash] group by:Column#25, funcs:count(Column#23)->Column#20, funcs:sum(Column#24)->Column#21", - " └─Projection 10000.00 batchCop[tiflash] plus(test.t.value, 2)->Column#23, plus(test.t.value, 2)->Column#24, test.t.id", + " └─HashAgg 8000.00 batchCop[tiflash] group by:Column#26, funcs:count(Column#24)->Column#21, funcs:sum(Column#25)->Column#22", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.value, 2)->Column#24, plus(test.t.value, 2)->Column#25, test.t.id", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] + }, + { + "SQL": "desc format = 'brief' SELECT FROM_UNIXTIME(name,'%Y-%m-%d') FROM t;", + "Plan": [ + "TableReader 10000.00 root data:ExchangeSender", + "└─ExchangeSender 10000.00 cop[tiflash] ExchangeType: PassThrough", + " └─Projection 10000.00 cop[tiflash] from_unixtime(cast(test.t.name, decimal(65,0) BINARY), %Y-%m-%d)->Column#5", + " └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + ] } ] }, diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index 7c839190b4282..4d2c468590362 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -1580,7 +1580,7 @@ " └─IndexFullScan 1.00 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo" ], "Warning": [ - "Scalar function 'intdiv'(signature: IntDivideInt) can not be pushed to storage layer", + "Scalar function 'intdiv'(signature: IntDivideInt, return type: bigint(20)) can not be pushed to storage layer", "[planner:1815]Optimizer Hint LIMIT_TO_COP is inapplicable" ] }, diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 
be55f4afade08..c8682684b2489 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -53,7 +53,7 @@ const globalDBVisible = mysql.CreatePriv | mysql.SelectPriv | mysql.InsertPriv | const ( sqlLoadRoleGraph = "SELECT HIGH_PRIORITY FROM_USER, FROM_HOST, TO_USER, TO_HOST FROM mysql.role_edges" sqlLoadGlobalPrivTable = "SELECT HIGH_PRIORITY Host,User,Priv FROM mysql.global_priv" - sqlLoadDBTable = "SELECT HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,References_priv,Lock_tables_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv FROM mysql.db ORDER BY host, db, user" + sqlLoadDBTable = "SELECT HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,References_priv,Lock_tables_priv,Create_tmp_table_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv FROM mysql.db ORDER BY host, db, user" sqlLoadTablePrivTable = "SELECT HIGH_PRIORITY Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv FROM mysql.tables_priv" sqlLoadColumnsPrivTable = "SELECT HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv FROM mysql.columns_priv" sqlLoadDefaultRoles = "SELECT HIGH_PRIORITY HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER FROM mysql.default_roles" diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index f6192e1a90a14..1f66a2cf25217 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -2461,7 +2461,7 @@ func TestPlacementPolicyStmt(t *testing.T) { defer clean() se := newSession(t, store, dbName) mustExec(t, se, "drop placement policy if exists x") - createStmt := "create placement policy x PRIMARY_REGION=\"cn-east-1\" " + createStmt := "create placement policy x PRIMARY_REGION=\"cn-east-1\" REGIONS=\"cn-east-1\"" dropStmt := "drop placement policy if exists x" // high privileged user 
setting password for other user (passes) @@ -2492,3 +2492,27 @@ func TestDBNameCaseSensitivityInTableLevel(t *testing.T) { mustExec(t, se, "CREATE USER test_user") mustExec(t, se, "grant select on metrics_schema.up to test_user;") } + +func TestGrantCreateTmpTables(t *testing.T) { + t.Parallel() + store, clean := newStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("CREATE DATABASE create_tmp_table_db") + tk.MustExec("USE create_tmp_table_db") + tk.MustExec("CREATE USER u1") + tk.MustExec("CREATE TABLE create_tmp_table_table (a int)") + tk.MustExec("GRANT CREATE TEMPORARY TABLES on create_tmp_table_db.* to u1") + tk.MustExec("GRANT CREATE TEMPORARY TABLES on *.* to u1") + // Must set a session user to avoid null pointer dereferencing + tk.Session().Auth(&auth.UserIdentity{ + Username: "root", + Hostname: "localhost", + }, nil, nil) + tk.MustQuery("SHOW GRANTS FOR u1").Check(testkit.Rows( + `GRANT CREATE TEMPORARY TABLES ON *.* TO 'u1'@'%'`, + `GRANT CREATE TEMPORARY TABLES ON create_tmp_table_db.* TO 'u1'@'%'`)) + tk.MustExec("DROP USER u1") + tk.MustExec("DROP DATABASE create_tmp_table_db") +} diff --git a/session/bootstrap.go b/session/bootstrap.go index 73c7e23967282..8a1044e59a20a 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -132,7 +132,7 @@ const ( Grantor CHAR(77), Timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, Table_priv SET('Select','Insert','Update','Delete','Create','Drop','Grant','Index','Alter','Create View','Show View','Trigger','References'), - Column_priv SET('Select','Insert','Update'), + Column_priv SET('Select','Insert','Update','References'), PRIMARY KEY (Host, DB, User, Table_name));` // CreateColumnPrivTable is the SQL statement creates column scope privilege table in system db. 
CreateColumnPrivTable = `CREATE TABLE IF NOT EXISTS mysql.columns_priv( @@ -142,7 +142,7 @@ const ( Table_name CHAR(64), Column_name CHAR(64), Timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - Column_priv SET('Select','Insert','Update'), + Column_priv SET('Select','Insert','Update','References'), PRIMARY KEY (Host, DB, User, Table_name, Column_name));` // CreateGlobalVariablesTable is the SQL statement creates global variable table in system db. // TODO: MySQL puts GLOBAL_VARIABLES table in INFORMATION_SCHEMA db. @@ -513,11 +513,13 @@ const ( version74 = 74 // version75 update mysql.*.host from char(60) to char(255) version75 = 75 + // version76 update mysql.columns_priv from SET('Select','Insert','Update') to SET('Select','Insert','Update','References') + version76 = 76 ) // currentBootstrapVersion is defined as a variable, so we can modify its value for testing. // please make sure this is the largest version -var currentBootstrapVersion int64 = version75 +var currentBootstrapVersion int64 = version76 var ( bootstrapVersion = []func(Session, int64){ @@ -596,6 +598,7 @@ var ( upgradeToVer73, upgradeToVer74, upgradeToVer75, + upgradeToVer76, } ) @@ -1571,6 +1574,13 @@ func upgradeToVer75(s Session, ver int64) { doReentrantDDL(s, "ALTER TABLE mysql.columns_priv MODIFY COLUMN Host CHAR(255)") } +func upgradeToVer76(s Session, ver int64) { + if ver >= version76 { + return + } + doReentrantDDL(s, "ALTER TABLE mysql.columns_priv MODIFY COLUMN Column_priv SET('Select','Insert','Update','References')") +} + func writeOOMAction(s Session) { comment := "oom-action is `log` by default in v3.0.x, `cancel` by default in v4.0.11+" mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) 
ON DUPLICATE KEY UPDATE VARIABLE_VALUE= %?`, diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index 68b46e75c98ab..23c31c78bf30a 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -845,3 +845,21 @@ func (s *testBootstrapSuite) TestForIssue23387(c *C) { c.Assert(len(rows), Equals, 1) c.Assert(rows[0][0], Equals, "GRANT USAGE ON *.* TO 'quatest'@'%'") } + +func (s *testBootstrapSuite) TestReferencesPrivOnCol(c *C) { + defer testleak.AfterTest(c)() + store, dom := newStoreWithBootstrap(c, s.dbName) + defer store.Close() + defer dom.Close() + se := newSession(c, store, s.dbName) + + defer func() { + mustExecSQL(c, se, "drop user if exists issue28531") + mustExecSQL(c, se, "drop table if exists t1") + }() + + mustExecSQL(c, se, "create user if not exists issue28531") + mustExecSQL(c, se, "drop table if exists t1") + mustExecSQL(c, se, "create table t1 (a int)") + mustExecSQL(c, se, "GRANT select (a), update (a),insert(a), references(a) on t1 to issue28531") +} diff --git a/session/session.go b/session/session.go index 18c8191c35dfb..4897a6ff44708 100644 --- a/session/session.go +++ b/session/session.go @@ -926,7 +926,10 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { zap.Uint("retryCnt", retryCnt), zap.Int("queryNum", i)) } + _, digest := s.sessionVars.StmtCtx.SQLDigest() + s.txn.onStmtStart(digest.String()) _, err = st.Exec(ctx) + s.txn.onStmtEnd() if err != nil { s.StmtRollback() break diff --git a/statistics/feedback_test.go b/statistics/feedback_test.go index 72fc106b99ded..ff4d2b4578ba8 100644 --- a/statistics/feedback_test.go +++ b/statistics/feedback_test.go @@ -16,20 +16,16 @@ package statistics import ( "bytes" + "testing" - . 
"github.com/pingcap/check" "github.com/pingcap/log" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) -var _ = Suite(&testFeedbackSuite{}) - -type testFeedbackSuite struct { -} - func newFeedback(lower, upper, count, ndv int64) Feedback { low, upp := types.NewIntDatum(lower), types.NewIntDatum(upper) return Feedback{&low, &upp, count, 0, ndv} @@ -58,7 +54,8 @@ func genHistogram() *Histogram { return h } -func (s *testFeedbackSuite) TestUpdateHistogram(c *C) { +func TestUpdateHistogram(t *testing.T) { + t.Parallel() feedbacks := []Feedback{ newFeedback(0, 1, 10000, 1), newFeedback(1, 2, 1, 1), @@ -74,17 +71,19 @@ func (s *testFeedbackSuite) TestUpdateHistogram(c *C) { originBucketCount := defaultBucketCount defaultBucketCount = 7 defer func() { defaultBucketCount = originBucketCount }() - c.Assert(UpdateHistogram(q.Hist, q, Version2).ToString(0), Equals, + require.Equal(t, "column:0 ndv:10053 totColSize:0\n"+ "num: 10001 lower_bound: 0 upper_bound: 2 repeats: 0 ndv: 2\n"+ "num: 7 lower_bound: 2 upper_bound: 5 repeats: 0 ndv: 2\n"+ "num: 4 lower_bound: 5 upper_bound: 7 repeats: 0 ndv: 1\n"+ "num: 11 lower_bound: 10 upper_bound: 20 repeats: 0 ndv: 11\n"+ "num: 19 lower_bound: 30 upper_bound: 49 repeats: 0 ndv: 19\n"+ - "num: 11 lower_bound: 50 upper_bound: 60 repeats: 0 ndv: 11") + "num: 11 lower_bound: 50 upper_bound: 60 repeats: 0 ndv: 11", + UpdateHistogram(q.Hist, q, Version2).ToString(0)) } -func (s *testFeedbackSuite) TestSplitBuckets(c *C) { +func TestSplitBuckets(t *testing.T) { + t.Parallel() // test bucket split feedbacks := []Feedback{newFeedback(0, 1, 1, 1)} for i := 0; i < 100; i++ { @@ -107,16 +106,17 @@ func (s *testFeedbackSuite) TestSplitBuckets(c *C) { ndvs[i] = buckets[i].Ndv } log.Warn("in test", zap.Int64s("ndvs", ndvs)) - c.Assert(buildNewHistogram(q.Hist, buckets).ToString(0), Equals, + require.Equal(t, "column:0 ndv:0 
totColSize:0\n"+ "num: 1 lower_bound: 0 upper_bound: 1 repeats: 0 ndv: 1\n"+ "num: 0 lower_bound: 2 upper_bound: 3 repeats: 0 ndv: 0\n"+ "num: 0 lower_bound: 5 upper_bound: 7 repeats: 0 ndv: 0\n"+ "num: 5 lower_bound: 10 upper_bound: 15 repeats: 0 ndv: 5\n"+ "num: 0 lower_bound: 16 upper_bound: 20 repeats: 0 ndv: 0\n"+ - "num: 0 lower_bound: 30 upper_bound: 50 repeats: 0 ndv: 0") - c.Assert(isNewBuckets, DeepEquals, []bool{false, false, false, true, true, false}) - c.Assert(totalCount, Equals, int64(6)) + "num: 0 lower_bound: 30 upper_bound: 50 repeats: 0 ndv: 0", + buildNewHistogram(q.Hist, buckets).ToString(0)) + require.Equal(t, []bool{false, false, false, true, true, false}, isNewBuckets) + require.Equal(t, int64(6), totalCount) // test do not split if the bucket count is too small feedbacks = []Feedback{newFeedback(0, 1, 100000, 1)} @@ -126,16 +126,17 @@ func (s *testFeedbackSuite) TestSplitBuckets(c *C) { q = NewQueryFeedback(0, genHistogram(), 0, false) q.Feedback = feedbacks buckets, isNewBuckets, totalCount = splitBuckets(q.Hist, q) - c.Assert(buildNewHistogram(q.Hist, buckets).ToString(0), Equals, + require.Equal(t, "column:0 ndv:0 totColSize:0\n"+ "num: 100000 lower_bound: 0 upper_bound: 1 repeats: 0 ndv: 1\n"+ "num: 0 lower_bound: 2 upper_bound: 3 repeats: 0 ndv: 0\n"+ "num: 0 lower_bound: 5 upper_bound: 7 repeats: 0 ndv: 0\n"+ "num: 1 lower_bound: 10 upper_bound: 15 repeats: 0 ndv: 1\n"+ "num: 0 lower_bound: 16 upper_bound: 20 repeats: 0 ndv: 0\n"+ - "num: 0 lower_bound: 30 upper_bound: 50 repeats: 0 ndv: 0") - c.Assert(isNewBuckets, DeepEquals, []bool{false, false, false, true, true, false}) - c.Assert(totalCount, Equals, int64(100001)) + "num: 0 lower_bound: 30 upper_bound: 50 repeats: 0 ndv: 0", + buildNewHistogram(q.Hist, buckets).ToString(0)) + require.Equal(t, []bool{false, false, false, true, true, false}, isNewBuckets) + require.Equal(t, int64(100001), totalCount) // test do not split if the result bucket count is too small h := NewHistogram(0, 
0, 0, 0, types.NewFieldType(mysql.TypeLong), 5, 0) @@ -149,11 +150,12 @@ func (s *testFeedbackSuite) TestSplitBuckets(c *C) { q = NewQueryFeedback(0, h, 0, false) q.Feedback = feedbacks buckets, isNewBuckets, totalCount = splitBuckets(q.Hist, q) - c.Assert(buildNewHistogram(q.Hist, buckets).ToString(0), Equals, + require.Equal(t, "column:0 ndv:0 totColSize:0\n"+ - "num: 1000000 lower_bound: 0 upper_bound: 1000000 repeats: 0 ndv: 1000000") - c.Assert(isNewBuckets, DeepEquals, []bool{false}) - c.Assert(totalCount, Equals, int64(1000000)) + "num: 1000000 lower_bound: 0 upper_bound: 1000000 repeats: 0 ndv: 1000000", + buildNewHistogram(q.Hist, buckets).ToString(0)) + require.Equal(t, []bool{false}, isNewBuckets) + require.Equal(t, int64(1000000), totalCount) // test split even if the feedback range is too small h = NewHistogram(0, 0, 0, 0, types.NewFieldType(mysql.TypeLong), 5, 0) @@ -165,12 +167,13 @@ func (s *testFeedbackSuite) TestSplitBuckets(c *C) { q = NewQueryFeedback(0, h, 0, false) q.Feedback = feedbacks buckets, isNewBuckets, totalCount = splitBuckets(q.Hist, q) - c.Assert(buildNewHistogram(q.Hist, buckets).ToString(0), Equals, + require.Equal(t, "column:0 ndv:0 totColSize:0\n"+ "num: 1 lower_bound: 0 upper_bound: 10 repeats: 0 ndv: 1\n"+ - "num: 0 lower_bound: 11 upper_bound: 1000000 repeats: 0 ndv: 0") - c.Assert(isNewBuckets, DeepEquals, []bool{true, true}) - c.Assert(totalCount, Equals, int64(1)) + "num: 0 lower_bound: 11 upper_bound: 1000000 repeats: 0 ndv: 0", + buildNewHistogram(q.Hist, buckets).ToString(0)) + require.Equal(t, []bool{true, true}, isNewBuckets) + require.Equal(t, int64(1), totalCount) // test merge the non-overlapped feedbacks. 
h = NewHistogram(0, 0, 0, 0, types.NewFieldType(mysql.TypeLong), 5, 0) @@ -181,14 +184,16 @@ func (s *testFeedbackSuite) TestSplitBuckets(c *C) { q = NewQueryFeedback(0, h, 0, false) q.Feedback = feedbacks buckets, isNewBuckets, totalCount = splitBuckets(q.Hist, q) - c.Assert(buildNewHistogram(q.Hist, buckets).ToString(0), Equals, + require.Equal(t, "column:0 ndv:0 totColSize:0\n"+ - "num: 5001 lower_bound: 0 upper_bound: 10000 repeats: 0 ndv: 5001") - c.Assert(isNewBuckets, DeepEquals, []bool{false}) - c.Assert(totalCount, Equals, int64(5001)) + "num: 5001 lower_bound: 0 upper_bound: 10000 repeats: 0 ndv: 5001", + buildNewHistogram(q.Hist, buckets).ToString(0)) + require.Equal(t, []bool{false}, isNewBuckets) + require.Equal(t, int64(5001), totalCount) } -func (s *testFeedbackSuite) TestMergeBuckets(c *C) { +func TestMergeBuckets(t *testing.T) { + t.Parallel() originBucketCount := defaultBucketCount defer func() { defaultBucketCount = originBucketCount }() tests := []struct { @@ -230,21 +235,19 @@ func (s *testFeedbackSuite) TestMergeBuckets(c *C) { "num: 100000 lower_bound: 4 upper_bound: 5 repeats: 0 ndv: 1", }, } - for _, t := range tests { - if len(t.counts) != len(t.ndvs) { - c.Assert(false, IsTrue) - } - bkts := make([]bucket, 0, len(t.counts)) + for _, tt := range tests { + require.Equal(t, len(tt.ndvs), len(tt.counts)) + bkts := make([]bucket, 0, len(tt.counts)) totalCount := int64(0) - for i := 0; i < len(t.counts); i++ { - lower, upper := types.NewIntDatum(t.points[2*i]), types.NewIntDatum(t.points[2*i+1]) - bkts = append(bkts, bucket{&lower, &upper, t.counts[i], 0, t.ndvs[i]}) - totalCount += t.counts[i] + for i := 0; i < len(tt.counts); i++ { + lower, upper := types.NewIntDatum(tt.points[2*i]), types.NewIntDatum(tt.points[2*i+1]) + bkts = append(bkts, bucket{&lower, &upper, tt.counts[i], 0, tt.ndvs[i]}) + totalCount += tt.counts[i] } - defaultBucketCount = t.bucketCount - bkts = mergeBuckets(bkts, t.isNewBuckets, float64(totalCount)) + 
defaultBucketCount = tt.bucketCount + bkts = mergeBuckets(bkts, tt.isNewBuckets, float64(totalCount)) result := buildNewHistogram(&Histogram{Tp: types.NewFieldType(mysql.TypeLong)}, bkts).ToString(0) - c.Assert(result, Equals, t.result) + require.Equal(t, tt.result, result) } } @@ -254,33 +257,34 @@ func encodeInt(v int64) *types.Datum { return &d } -func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { +func TestFeedbackEncoding(t *testing.T) { + t.Parallel() hist := NewHistogram(0, 0, 0, 0, types.NewFieldType(mysql.TypeLong), 0, 0) q := &QueryFeedback{Hist: hist, Tp: PkType} q.Feedback = append(q.Feedback, Feedback{encodeInt(0), encodeInt(3), 1, 0, 1}) q.Feedback = append(q.Feedback, Feedback{encodeInt(0), encodeInt(5), 1, 0, 1}) val, err := EncodeFeedback(q) - c.Assert(err, IsNil) + require.NoError(t, err) rq := &QueryFeedback{} - c.Assert(DecodeFeedback(val, rq, nil, nil, hist.Tp), IsNil) + require.NoError(t, DecodeFeedback(val, rq, nil, nil, hist.Tp)) for _, fb := range rq.Feedback { fb.Lower.SetBytes(codec.EncodeInt(nil, fb.Lower.GetInt64())) fb.Upper.SetBytes(codec.EncodeInt(nil, fb.Upper.GetInt64())) } - c.Assert(q.Equal(rq), IsTrue) + require.True(t, q.Equal(rq)) hist.Tp = types.NewFieldType(mysql.TypeBlob) q = &QueryFeedback{Hist: hist} q.Feedback = append(q.Feedback, Feedback{encodeInt(0), encodeInt(3), 1, 0, 1}) q.Feedback = append(q.Feedback, Feedback{encodeInt(0), encodeInt(1), 1, 0, 1}) val, err = EncodeFeedback(q) - c.Assert(err, IsNil) + require.NoError(t, err) rq = &QueryFeedback{} cms := NewCMSketch(4, 4) - c.Assert(DecodeFeedback(val, rq, cms, nil, hist.Tp), IsNil) - c.Assert(cms.QueryBytes(codec.EncodeInt(nil, 0)), Equals, uint64(1)) + require.NoError(t, DecodeFeedback(val, rq, cms, nil, hist.Tp)) + require.Equal(t, uint64(1), cms.QueryBytes(codec.EncodeInt(nil, 0))) q.Feedback = q.Feedback[:1] - c.Assert(q.Equal(rq), IsTrue) + require.True(t, q.Equal(rq)) } // Equal tests if two query feedback equal, it is only used in test. 
diff --git a/statistics/sample_serial_test.go b/statistics/sample_serial_test.go new file mode 100644 index 0000000000000..94f30b11f32f4 --- /dev/null +++ b/statistics/sample_serial_test.go @@ -0,0 +1,153 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statistics + +import ( + "testing" + "time" + + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/collate" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/stretchr/testify/require" +) + +type testSampleSuite struct { + count int + rs sqlexec.RecordSet +} + +func TestSampleSerial(t *testing.T) { + s := createTestSampleSuite() + t.Run("SubTestCollectColumnStats", SubTestCollectColumnStats(s)) + t.Run("SubTestMergeSampleCollector", SubTestMergeSampleCollector(s)) + t.Run("SubTestCollectorProtoConversion", SubTestCollectorProtoConversion(s)) +} + +func createTestSampleSuite() *testSampleSuite { + s := new(testSampleSuite) + s.count = 10000 + rs := &recordSet{ + data: make([]types.Datum, s.count), + count: s.count, + cursor: 0, + firstIsID: true, + } + rs.setFields(mysql.TypeLonglong, mysql.TypeLonglong) + start := 1000 // 1000 values is null + for i := start; i < rs.count; i++ { + rs.data[i].SetInt64(int64(i)) + } + for i := start; i < rs.count; i += 3 { + rs.data[i].SetInt64(rs.data[i].GetInt64() + 1) + } + for i := start; i < 
rs.count; i += 5 { + rs.data[i].SetInt64(rs.data[i].GetInt64() + 2) + } + s.rs = rs + return s +} + +func SubTestCollectColumnStats(s *testSampleSuite) func(*testing.T) { + return func(t *testing.T) { + sc := mock.NewContext().GetSessionVars().StmtCtx + builder := SampleBuilder{ + Sc: sc, + RecordSet: s.rs, + ColLen: 1, + PkBuilder: NewSortedBuilder(sc, 256, 1, types.NewFieldType(mysql.TypeLonglong), Version2), + MaxSampleSize: 10000, + MaxBucketSize: 256, + MaxFMSketchSize: 1000, + CMSketchWidth: 2048, + CMSketchDepth: 8, + Collators: make([]collate.Collator, 1), + ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, + } + require.Nil(t, s.rs.Close()) + collectors, pkBuilder, err := builder.CollectColumnStats() + require.NoError(t, err) + + require.Equal(t, int64(s.count), collectors[0].NullCount+collectors[0].Count) + require.Equal(t, int64(6232), collectors[0].FMSketch.NDV()) + require.Equal(t, uint64(collectors[0].Count), collectors[0].CMSketch.TotalCount()) + require.Equal(t, int64(s.count), pkBuilder.Count) + require.Equal(t, int64(s.count), pkBuilder.Hist().NDV) + } +} + +func SubTestMergeSampleCollector(s *testSampleSuite) func(*testing.T) { + return func(t *testing.T) { + builder := SampleBuilder{ + Sc: mock.NewContext().GetSessionVars().StmtCtx, + RecordSet: s.rs, + ColLen: 2, + MaxSampleSize: 1000, + MaxBucketSize: 256, + MaxFMSketchSize: 1000, + CMSketchWidth: 2048, + CMSketchDepth: 8, + Collators: make([]collate.Collator, 2), + ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, + } + require.Nil(t, s.rs.Close()) + sc := &stmtctx.StatementContext{TimeZone: time.Local} + collectors, pkBuilder, err := builder.CollectColumnStats() + require.NoError(t, err) + require.Nil(t, pkBuilder) + require.Len(t, collectors, 2) + collectors[0].IsMerger = true + collectors[0].MergeSampleCollector(sc, collectors[1]) + require.Equal(t, int64(9280), collectors[0].FMSketch.NDV()) + 
require.Len(t, collectors[0].Samples, 1000) + require.Equal(t, int64(1000), collectors[0].NullCount) + require.Equal(t, int64(19000), collectors[0].Count) + require.Equal(t, uint64(collectors[0].Count), collectors[0].CMSketch.TotalCount()) + } +} + +func SubTestCollectorProtoConversion(s *testSampleSuite) func(*testing.T) { + return func(t *testing.T) { + builder := SampleBuilder{ + Sc: mock.NewContext().GetSessionVars().StmtCtx, + RecordSet: s.rs, + ColLen: 2, + MaxSampleSize: 10000, + MaxBucketSize: 256, + MaxFMSketchSize: 1000, + CMSketchWidth: 2048, + CMSketchDepth: 8, + Collators: make([]collate.Collator, 2), + ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, + } + require.Nil(t, s.rs.Close()) + collectors, pkBuilder, err := builder.CollectColumnStats() + require.NoError(t, err) + require.Nil(t, pkBuilder) + for _, collector := range collectors { + p := SampleCollectorToProto(collector) + s := SampleCollectorFromProto(p) + require.Equal(t, s.Count, collector.Count) + require.Equal(t, s.NullCount, collector.NullCount) + require.Equal(t, s.CMSketch.TotalCount(), collector.CMSketch.TotalCount()) + require.Equal(t, s.FMSketch.NDV(), collector.FMSketch.NDV()) + require.Equal(t, s.TotalSize, collector.TotalSize) + require.Equal(t, len(s.Samples), len(collector.Samples)) + } + } +} diff --git a/statistics/sample_test.go b/statistics/sample_test.go index 7553cd6f3cfae..082d7d8016e83 100644 --- a/statistics/sample_test.go +++ b/statistics/sample_test.go @@ -16,129 +16,17 @@ package statistics import ( "math/rand" + "testing" "time" - . 
"github.com/pingcap/check" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/mock" - "github.com/pingcap/tidb/util/sqlexec" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testSampleSuite{}) - -type testSampleSuite struct { - count int - rs sqlexec.RecordSet -} - -func (s *testSampleSuite) SetUpSuite(c *C) { - s.count = 10000 - rs := &recordSet{ - data: make([]types.Datum, s.count), - count: s.count, - cursor: 0, - firstIsID: true, - } - rs.setFields(mysql.TypeLonglong, mysql.TypeLonglong) - start := 1000 // 1000 values is null - for i := start; i < rs.count; i++ { - rs.data[i].SetInt64(int64(i)) - } - for i := start; i < rs.count; i += 3 { - rs.data[i].SetInt64(rs.data[i].GetInt64() + 1) - } - for i := start; i < rs.count; i += 5 { - rs.data[i].SetInt64(rs.data[i].GetInt64() + 2) - } - s.rs = rs -} - -func (s *testSampleSuite) TestCollectColumnStats(c *C) { - sc := mock.NewContext().GetSessionVars().StmtCtx - builder := SampleBuilder{ - Sc: sc, - RecordSet: s.rs, - ColLen: 1, - PkBuilder: NewSortedBuilder(sc, 256, 1, types.NewFieldType(mysql.TypeLonglong), Version2), - MaxSampleSize: 10000, - MaxBucketSize: 256, - MaxFMSketchSize: 1000, - CMSketchWidth: 2048, - CMSketchDepth: 8, - Collators: make([]collate.Collator, 1), - ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, - } - c.Assert(s.rs.Close(), IsNil) - collectors, pkBuilder, err := builder.CollectColumnStats() - c.Assert(err, IsNil) - c.Assert(collectors[0].NullCount+collectors[0].Count, Equals, int64(s.count)) - c.Assert(collectors[0].FMSketch.NDV(), Equals, int64(6232)) - c.Assert(collectors[0].CMSketch.TotalCount(), Equals, uint64(collectors[0].Count)) - c.Assert(pkBuilder.Count, Equals, int64(s.count)) - c.Assert(pkBuilder.Hist().NDV, Equals, int64(s.count)) -} - -func (s *testSampleSuite) TestMergeSampleCollector(c *C) { - builder := 
SampleBuilder{ - Sc: mock.NewContext().GetSessionVars().StmtCtx, - RecordSet: s.rs, - ColLen: 2, - MaxSampleSize: 1000, - MaxBucketSize: 256, - MaxFMSketchSize: 1000, - CMSketchWidth: 2048, - CMSketchDepth: 8, - Collators: make([]collate.Collator, 2), - ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, - } - c.Assert(s.rs.Close(), IsNil) - sc := &stmtctx.StatementContext{TimeZone: time.Local} - collectors, pkBuilder, err := builder.CollectColumnStats() - c.Assert(err, IsNil) - c.Assert(pkBuilder, IsNil) - c.Assert(len(collectors), Equals, 2) - collectors[0].IsMerger = true - collectors[0].MergeSampleCollector(sc, collectors[1]) - c.Assert(collectors[0].FMSketch.NDV(), Equals, int64(9280)) - c.Assert(len(collectors[0].Samples), Equals, 1000) - c.Assert(collectors[0].NullCount, Equals, int64(1000)) - c.Assert(collectors[0].Count, Equals, int64(19000)) - c.Assert(collectors[0].CMSketch.TotalCount(), Equals, uint64(collectors[0].Count)) -} - -func (s *testSampleSuite) TestCollectorProtoConversion(c *C) { - builder := SampleBuilder{ - Sc: mock.NewContext().GetSessionVars().StmtCtx, - RecordSet: s.rs, - ColLen: 2, - MaxSampleSize: 10000, - MaxBucketSize: 256, - MaxFMSketchSize: 1000, - CMSketchWidth: 2048, - CMSketchDepth: 8, - Collators: make([]collate.Collator, 2), - ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, - } - c.Assert(s.rs.Close(), IsNil) - collectors, pkBuilder, err := builder.CollectColumnStats() - c.Assert(err, IsNil) - c.Assert(pkBuilder, IsNil) - for _, collector := range collectors { - p := SampleCollectorToProto(collector) - s := SampleCollectorFromProto(p) - c.Assert(collector.Count, Equals, s.Count) - c.Assert(collector.NullCount, Equals, s.NullCount) - c.Assert(collector.CMSketch.TotalCount(), Equals, s.CMSketch.TotalCount()) - c.Assert(collector.FMSketch.NDV(), Equals, s.FMSketch.NDV()) - c.Assert(collector.TotalSize, 
Equals, s.TotalSize) - c.Assert(len(collector.Samples), Equals, len(s.Samples)) - } -} - -func (s *testSampleSuite) recordSetForWeightSamplingTest(size int) *recordSet { +func recordSetForWeightSamplingTest(size int) *recordSet { r := &recordSet{ data: make([]types.Datum, 0, size), count: size, @@ -150,7 +38,7 @@ func (s *testSampleSuite) recordSetForWeightSamplingTest(size int) *recordSet { return r } -func (s *testSampleSuite) recordSetForDistributedSamplingTest(size, batch int) []*recordSet { +func recordSetForDistributedSamplingTest(size, batch int) []*recordSet { sets := make([]*recordSet, 0, batch) batchSize := size / batch for i := 0; i < batch; i++ { @@ -167,11 +55,12 @@ func (s *testSampleSuite) recordSetForDistributedSamplingTest(size, batch int) [ return sets } -func (s *testSampleSuite) TestWeightedSampling(c *C) { +func TestWeightedSampling(t *testing.T) { + t.Parallel() sampleNum := int64(20) rowNum := 100 loopCnt := 1000 - rs := s.recordSetForWeightSamplingTest(rowNum) + rs := recordSetForWeightSamplingTest(rowNum) sc := mock.NewContext().GetSessionVars().StmtCtx // The loop which is commented out is used for stability test. // This test can run 800 times in a row without any failure. 
@@ -189,29 +78,29 @@ func (s *testSampleSuite) TestWeightedSampling(c *C) { Rng: rand.New(rand.NewSource(time.Now().UnixNano())), } collector, err := builder.Collect() - c.Assert(err, IsNil) + require.NoError(t, err) for i := 0; i < collector.MaxSampleSize; i++ { a := collector.Samples[i].Columns[0].GetInt64() itemCnt[a]++ } - c.Assert(rs.Close(), IsNil) + require.Nil(t, rs.Close()) } expFrequency := float64(sampleNum) * float64(loopCnt) / float64(rowNum) delta := 0.5 for _, cnt := range itemCnt { if float64(cnt) < expFrequency/(1+delta) || float64(cnt) > expFrequency*(1+delta) { - c.Assert(false, IsTrue, Commentf("The frequency %v is exceed the Chernoff Bound", cnt)) + require.Truef(t, false, "The frequency %v is exceed the Chernoff Bound", cnt) } } - // } } -func (s *testSampleSuite) TestDistributedWeightedSampling(c *C) { +func TestDistributedWeightedSampling(t *testing.T) { + t.Parallel() sampleNum := int64(10) rowNum := 100 loopCnt := 1500 batch := 5 - sets := s.recordSetForDistributedSamplingTest(rowNum, batch) + sets := recordSetForDistributedSamplingTest(rowNum, batch) sc := mock.NewContext().GetSessionVars().StmtCtx // The loop which is commented out is used for stability test. // This test can run 800 times in a row without any failure. 
@@ -232,9 +121,9 @@ func (s *testSampleSuite) TestDistributedWeightedSampling(c *C) { Rng: rand.New(rand.NewSource(time.Now().UnixNano())), } collector, err := builder.Collect() - c.Assert(err, IsNil) + require.NoError(t, err) rootRowCollector.MergeCollector(collector) - c.Assert(sets[i].Close(), IsNil) + require.Nil(t, sets[i].Close()) } for _, sample := range rootRowCollector.Samples { itemCnt[sample.Columns[0].GetInt64()]++ @@ -244,44 +133,44 @@ func (s *testSampleSuite) TestDistributedWeightedSampling(c *C) { delta := 0.5 for _, cnt := range itemCnt { if float64(cnt) < expFrequency/(1+delta) || float64(cnt) > expFrequency*(1+delta) { - c.Assert(false, IsTrue, Commentf("the frequency %v is exceed the Chernoff Bound", cnt)) + require.Truef(t, false, "the frequency %v is exceed the Chernoff Bound", cnt) } } - // } } -func (s *testSampleSuite) TestBuildStatsOnRowSample(c *C) { +func TestBuildStatsOnRowSample(t *testing.T) { + t.Parallel() ctx := mock.NewContext() sketch := NewFMSketch(1000) data := make([]*SampleItem, 0, 8) for i := 1; i <= 1000; i++ { d := types.NewIntDatum(int64(i)) err := sketch.InsertValue(ctx.GetSessionVars().StmtCtx, d) - c.Assert(err, IsNil) + require.NoError(t, err) data = append(data, &SampleItem{Value: d}) } for i := 1; i < 10; i++ { d := types.NewIntDatum(int64(2)) err := sketch.InsertValue(ctx.GetSessionVars().StmtCtx, d) - c.Assert(err, IsNil) + require.NoError(t, err) data = append(data, &SampleItem{Value: d}) } for i := 1; i < 7; i++ { d := types.NewIntDatum(int64(4)) err := sketch.InsertValue(ctx.GetSessionVars().StmtCtx, d) - c.Assert(err, IsNil) + require.NoError(t, err) data = append(data, &SampleItem{Value: d}) } for i := 1; i < 5; i++ { d := types.NewIntDatum(int64(7)) err := sketch.InsertValue(ctx.GetSessionVars().StmtCtx, d) - c.Assert(err, IsNil) + require.NoError(t, err) data = append(data, &SampleItem{Value: d}) } for i := 1; i < 3; i++ { d := types.NewIntDatum(int64(11)) err := 
sketch.InsertValue(ctx.GetSessionVars().StmtCtx, d) - c.Assert(err, IsNil) + require.NoError(t, err) data = append(data, &SampleItem{Value: d}) } collector := &SampleCollector{ @@ -293,16 +182,14 @@ func (s *testSampleSuite) TestBuildStatsOnRowSample(c *C) { } tp := types.NewFieldType(mysql.TypeLonglong) hist, topN, err := BuildHistAndTopN(ctx, 5, 4, 1, collector, tp, true) - c.Assert(err, IsNil, Commentf("%+v", err)) + require.Nilf(t, err, "%+v", err) topNStr, err := topN.DecodedString(ctx, []byte{tp.Tp}) - c.Assert(err, IsNil) - c.Assert(topNStr, Equals, "TopN{length: 4, [(2, 10), (4, 7), (7, 5), (11, 3)]}") - c.Assert(hist.ToString(0), Equals, "column:1 ndv:1000 totColSize:8168\n"+ + require.NoError(t, err) + require.Equal(t, "TopN{length: 4, [(2, 10), (4, 7), (7, 5), (11, 3)]}", topNStr) + require.Equal(t, "column:1 ndv:1000 totColSize:8168\n"+ "num: 200 lower_bound: 1 upper_bound: 204 repeats: 1 ndv: 0\n"+ "num: 200 lower_bound: 205 upper_bound: 404 repeats: 1 ndv: 0\n"+ "num: 200 lower_bound: 405 upper_bound: 604 repeats: 1 ndv: 0\n"+ "num: 200 lower_bound: 605 upper_bound: 804 repeats: 1 ndv: 0\n"+ - "num: 196 lower_bound: 805 upper_bound: 1000 repeats: 1 ndv: 0", - ) - + "num: 196 lower_bound: 805 upper_bound: 1000 repeats: 1 ndv: 0", hist.ToString(0)) } diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 9d63e72e309e5..0dc813767907b 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -165,7 +165,7 @@ func TestBasic(t *testing.T) { alc := tb.Allocators(nil).Get(autoid.RowIDAllocType) require.NotNil(t, alc) - err = alc.Rebase(0, false) + err = alc.Rebase(context.Background(), 0, false) require.NoError(t, err) } @@ -419,7 +419,7 @@ func TestTableFromMeta(t *testing.T) { require.NoError(t, err) maxID := 1<<(64-15-1) - 1 - err = tb.Allocators(tk.Session()).Get(autoid.RowIDAllocType).Rebase(int64(maxID), false) + err = tb.Allocators(tk.Session()).Get(autoid.RowIDAllocType).Rebase(context.Background(), 
int64(maxID), false) require.NoError(t, err) _, err = tables.AllocHandle(context.Background(), tk.Session(), tb) diff --git a/types/datum.go b/types/datum.go index 4368a6631e6f1..91bb015f3182c 100644 --- a/types/datum.go +++ b/types/datum.go @@ -2066,7 +2066,11 @@ func DatumsToString(datums []Datum, handleSpecialValue bool) (string, error) { if err != nil { return "", errors.Trace(err) } - strs = append(strs, str) + if datum.Kind() == KindString { + strs = append(strs, fmt.Sprintf("%q", str)) + } else { + strs = append(strs, str) + } } size := len(datums) if size > 1 { diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 4c3b1d9ddfd51..4265145c2d66a 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -415,16 +415,18 @@ const ( TpSelectResultRuntimeStats // TpInsertRuntimeStat is the tp for InsertRuntimeStat TpInsertRuntimeStat - // TpIndexLookUpRunTimeStats is the tp for TpIndexLookUpRunTimeStats + // TpIndexLookUpRunTimeStats is the tp for IndexLookUpRunTimeStats TpIndexLookUpRunTimeStats - // TpSlowQueryRuntimeStat is the tp for TpSlowQueryRuntimeStat + // TpSlowQueryRuntimeStat is the tp for SlowQueryRuntimeStat TpSlowQueryRuntimeStat // TpHashAggRuntimeStat is the tp for HashAggRuntimeStat TpHashAggRuntimeStat - // TpIndexMergeRunTimeStats is the tp for TpIndexMergeRunTimeStats + // TpIndexMergeRunTimeStats is the tp for IndexMergeRunTimeStats TpIndexMergeRunTimeStats - // TpBasicCopRunTimeStats is the tp for TpBasicCopRunTimeStats + // TpBasicCopRunTimeStats is the tp for BasicCopRunTimeStats TpBasicCopRunTimeStats + // TpUpdateRuntimeStats is the tp for UpdateRuntimeStats + TpUpdateRuntimeStats ) // RuntimeStats is used to express the executor runtime information. @@ -761,6 +763,7 @@ func (e *RuntimeStatsWithConcurrencyInfo) Merge(_ RuntimeStats) { // RuntimeStatsWithCommit is the RuntimeStats with commit detail. 
type RuntimeStatsWithCommit struct { Commit *util.CommitDetails + TxnCnt int LockKeys *util.LockKeysDetails } @@ -769,12 +772,27 @@ func (e *RuntimeStatsWithCommit) Tp() int { return TpRuntimeStatsWithCommit } +// MergeCommitDetails merges the commit details. +func (e *RuntimeStatsWithCommit) MergeCommitDetails(detail *util.CommitDetails) { + if detail == nil { + return + } + if e.Commit == nil { + e.Commit = detail + e.TxnCnt = 1 + return + } + e.Commit.Merge(detail) + e.TxnCnt++ +} + // Merge implements the RuntimeStats interface. func (e *RuntimeStatsWithCommit) Merge(rs RuntimeStats) { tmp, ok := rs.(*RuntimeStatsWithCommit) if !ok { return } + e.TxnCnt += tmp.TxnCnt if tmp.Commit != nil { if e.Commit == nil { e.Commit = &util.CommitDetails{} @@ -792,7 +810,9 @@ func (e *RuntimeStatsWithCommit) Merge(rs RuntimeStats) { // Clone implements the RuntimeStats interface. func (e *RuntimeStatsWithCommit) Clone() RuntimeStats { - newRs := RuntimeStatsWithCommit{} + newRs := RuntimeStatsWithCommit{ + TxnCnt: e.TxnCnt, + } if e.Commit != nil { newRs.Commit = e.Commit.Clone() } @@ -807,6 +827,12 @@ func (e *RuntimeStatsWithCommit) String() string { buf := bytes.NewBuffer(make([]byte, 0, 32)) if e.Commit != nil { buf.WriteString("commit_txn: {") + // Only print out when there are more than 1 transaction. + if e.TxnCnt > 1 { + buf.WriteString("count: ") + buf.WriteString(strconv.Itoa(e.TxnCnt)) + buf.WriteString(", ") + } if e.Commit.PrewriteTime > 0 { buf.WriteString("prewrite:") buf.WriteString(FormatDuration(e.Commit.PrewriteTime))