From 37dbe78bd3144444fe88f3ec01feab3b88e5e060 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 8 Jul 2021 16:38:01 +0800 Subject: [PATCH 1/5] *: finish insert/batch_get for local temporary table --- executor/point_get.go | 6 ++ go.mod | 2 + session/session.go | 80 +++++++++++++++++++++- session/session_test.go | 95 +++++++++++++++++++++++++++ sessionctx/variable/session.go | 41 ++++++++++++ store/driver/txn/txn_driver.go | 2 +- store/driver/txn/unionstore_driver.go | 2 +- table/tables/index.go | 5 +- table/tables/tables.go | 13 +++- util/tableutil/tableutil.go | 2 + 10 files changed, 243 insertions(+), 5 deletions(-) diff --git a/executor/point_get.go b/executor/point_get.go index 3b6e05f589315..83678d29330c0 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -390,6 +390,12 @@ func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) if e.tblInfo.TempTableType == model.TempTableGlobal { return nil, nil } + + // Local temporary table always get snapshot value from session + if e.tblInfo.TempTableType == model.TempTableLocal { + return e.ctx.GetSessionVars().GetTemporaryTableSnapshotValue(ctx, key) + } + lock := e.tblInfo.Lock if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) { if e.ctx.GetSessionVars().EnablePointGetCache { diff --git a/go.mod b/go.mod index 8a354b9b08e15..6a1f0c7a8e77e 100644 --- a/go.mod +++ b/go.mod @@ -82,3 +82,5 @@ require ( ) go 1.16 + +replace github.com/tikv/client-go/v2 => /Users/wangchao/Code/pingcap/client-go diff --git a/session/session.go b/session/session.go index 25fae80a2b7d1..d3e9b15e673b9 100644 --- a/session/session.go +++ b/session/session.go @@ -18,6 +18,7 @@ package session import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -41,6 +42,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" + transaction "github.com/pingcap/tidb/store/driver/txn" "github.com/pingcap/tidb/util/topsql" "github.com/pingcap/tipb/go-binlog" "go.uber.org/zap" @@ -541,7 +543,83 @@ func (s *session) doCommit(ctx context.Context) error { s.txn.SetOption(kv.KVFilter, temporaryTableKVFilter(tables)) } - return s.txn.Commit(tikvutil.SetSessionID(ctx, sessVars.ConnectionID)) + return s.commitTxnWithTemporaryData(tikvutil.SetSessionID(ctx, sessVars.ConnectionID), &s.txn) +} + +func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transaction) error { + txnTempTables := s.sessionVars.TxnCtx.TemporaryTables + if len(txnTempTables) == 0 { + return txn.Commit(ctx) + } + + sessionData := s.sessionVars.TemporaryTableData + var stage kv.StagingHandle + for tblID, tbl := range txnTempTables { + if !tbl.GetModified() { + continue + } + + if tbl.GetMeta().TempTableType != model.TempTableLocal { + continue + } + + if sessionData == nil { + sessionData = transaction.NewMemBuffer(tikv.NewMemDB()) + } + + if stage == kv.InvalidStagingHandle { + stage = sessionData.Staging() + } + + tblPrefix := tablecodec.EncodeTablePrefix(tblID) + endKey := tablecodec.EncodeTablePrefix(tblID + 1) + + txnMemBuffer := s.txn.GetMemBuffer() + iter, err := txnMemBuffer.Iter(tblPrefix, endKey) + if err != nil { + return err + } + + for iter.Valid() { + key := iter.Key() + if !bytes.HasPrefix(key, tblPrefix) { + break + } + + value := iter.Value() + if len(value) == 0 { + err = sessionData.Delete(key) + } else { + err = sessionData.Set(key, iter.Value()) + } + + if err != nil { + return err + } + + err = txnMemBuffer.Delete(key) + if err != nil { + return err + } + + err = iter.Next() + if err != nil { + return err + } + } + } + + err := txn.Commit(ctx) + if stage != kv.InvalidStagingHandle { + if err != nil { + sessionData.Cleanup(stage) + } else { + sessionData.Release(stage) + s.sessionVars.TemporaryTableData = sessionData + } + } + + return err } type temporaryTableKVFilter map[int64]tableutil.TempTable diff --git a/session/session_test.go b/session/session_test.go index 1b96824bd06b2..06c3a7042d568 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4859,3 +4859,98 @@ func (s *testSessionSuite) TestAuthPluginForUser(c *C) { c.Assert(err, IsNil) c.Assert(plugin, Equals, "") } + +func (s *testSessionSuite) TestLocalTemporaryTableInsert(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 (u, v) values(11, 101)") + tk.MustExec("insert into tmp1 (u, v) values(12, 102)") + tk.MustExec("insert into tmp1 values(3, 13, 102)") + + checkRecordOneTwoThreeAndNonExist := func() { + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102")) + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 102")) + tk.MustQuery("select * from tmp1 where id=99").Check(testkit.Rows()) + } + + // inserted records exist + checkRecordOneTwoThreeAndNonExist() + + // insert dup records out txn must be error + _, err := tk.Exec("insert into tmp1 values(1, 999, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + _, err = tk.Exec("insert into tmp1 values(99, 11, 999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + // insert dup records in txn must be error + tk.MustExec("begin") + _, err = tk.Exec("insert into tmp1 values(1, 999, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + _, err = tk.Exec("insert into tmp1 values(99, 11, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + tk.MustExec("insert into tmp1 values(4, 14, 104)") + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104")) + + _, err = tk.Exec("insert into tmp1 values(4, 999, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + + _, err = tk.Exec("insert into tmp1 values(99, 14, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + + checkRecordOneTwoThreeAndNonExist() + tk.MustExec("commit") + + // check committed insert works + checkRecordOneTwoThreeAndNonExist() + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104")) + + // check rollback works + tk.MustExec("begin") + tk.MustExec("insert into tmp1 values(5, 15, 105)") + tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows("5 15 105")) + tk.MustExec("rollback") + tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows()) +} + +func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 values(1, 11, 101)") + tk.MustExec("insert into tmp1 values(2, 12, 102)") + + // check point get out transaction + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where u=11").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102")) + tk.MustQuery("select * from tmp1 where u=12").Check(testkit.Rows("2 12 102")) + + // check point get in transaction + tk.MustExec("begin") + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where u=11").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102")) + tk.MustQuery("select * from tmp1 where u=12").Check(testkit.Rows("2 12 102")) + tk.MustExec("insert into tmp1 values(3, 13, 103)") + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103")) + tk.MustExec("update tmp1 set v=999 where id=2") + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999")) + tk.MustExec("commit") + + // check point get after transaction + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999")) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 7166f6ec44bb6..8a0a99f42c1e6 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -15,6 +15,7 @@ package variable import ( "bytes" + "context" "crypto/tls" "encoding/binary" "fmt" @@ -859,6 +860,9 @@ type SessionVars struct { // LocalTemporaryTables is *infoschema.LocalTemporaryTables, use interface to avoid circle dependency. // It's nil if there is no local temporary table. LocalTemporaryTables interface{} + + // TemporaryTableData stores committed kv values for temporary table for current session. + TemporaryTableData kv.MemBuffer } // AllocMPPTaskID allocates task id for mpp tasks. It will reset the task id if the query's @@ -2199,3 +2203,40 @@ func (s *SessionVars) GetSeekFactor(tbl *model.TableInfo) float64 { } return s.seekFactor } + +// GetTemporaryTableSnapshotValue get temporary table value from session +func (s *SessionVars) GetTemporaryTableSnapshotValue(ctx context.Context, key kv.Key) ([]byte, error) { + memData := s.TemporaryTableData + if memData == nil { + return nil, kv.ErrNotExist + } + + v, err := memData.Get(ctx, key) + if err != nil { + return v, err + } + + if len(v) == 0 { + return nil, kv.ErrNotExist + } + + return v, nil +} + +// GetTemporaryTableTxnValue returns a kv.Getter to fetch temporary table data in txn +func (s *SessionVars) GetTemporaryTableTxnValue(ctx context.Context, txn kv.Transaction, key kv.Key) ([]byte, error) { + v, err := txn.GetMemBuffer().Get(ctx, key) + if err == nil { + if len(v) == 0 { + return nil, kv.ErrNotExist + } + + return v, nil + } + + if !kv.IsErrNotFound(err) { + return v, err + } + + return s.GetTemporaryTableSnapshotValue(ctx, key) +} diff --git a/store/driver/txn/txn_driver.go b/store/driver/txn/txn_driver.go index 3c0500e9934ff..98841cf64d597 100644 --- a/store/driver/txn/txn_driver.go +++ b/store/driver/txn/txn_driver.go @@ -118,7 +118,7 @@ func (txn *tikvTxn) Set(k kv.Key, v []byte) error { } func (txn *tikvTxn) GetMemBuffer() kv.MemBuffer { - return newMemBuffer(txn.KVTxn.GetMemBuffer()) + return NewMemBuffer(txn.KVTxn.GetMemBuffer()) } func (txn *tikvTxn) SetOption(opt int, val interface{}) { diff --git a/store/driver/txn/unionstore_driver.go b/store/driver/txn/unionstore_driver.go index d58eca5fd552d..aba35b7b69b35 100644 --- a/store/driver/txn/unionstore_driver.go +++ b/store/driver/txn/unionstore_driver.go @@ -27,7 +27,7 @@ type memBuffer struct { *tikv.MemDB } -func newMemBuffer(m *tikv.MemDB) kv.MemBuffer { +func NewMemBuffer(m *tikv.MemDB) kv.MemBuffer { if m == nil { return nil } diff --git a/table/tables/index.go b/table/tables/index.go index aef03d0590aaa..ae0eca1339482 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -199,7 +199,10 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue } var value []byte - if sctx.GetSessionVars().LazyCheckKeyNotExists() { + if c.tblInfo.TempTableType != model.TempTableNone { + // Always check key for temporary table because it does not write to TiKV + value, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key) + } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { value, err = txn.GetMemBuffer().Get(ctx, key) } else { value, err = txn.Get(ctx, key) diff --git a/table/tables/tables.go b/table/tables/tables.go index 5ef1707ad8010..bb89cf7f7a26f 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -770,7 +770,10 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . var setPresume bool skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck if (t.meta.IsCommonHandle || t.meta.PKIsHandle) && !skipCheck && !opt.SkipHandleCheck { - if sctx.GetSessionVars().LazyCheckKeyNotExists() { + if t.meta.TempTableType != model.TempTableNone { + // Always check key for temporary table because it does not write to TiKV + _, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key) + } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { var v []byte v, err = txn.GetMemBuffer().Get(ctx, key) if err != nil { @@ -1827,6 +1830,8 @@ type TemporaryTable struct { autoIDAllocator autoid.Allocator // Table size. size int64 + + meta *model.TableInfo } // TempTableFromMeta builds a TempTable from model.TableInfo. @@ -1835,6 +1840,7 @@ func TempTableFromMeta(tblInfo *model.TableInfo) tableutil.TempTable { modified: false, stats: statistics.PseudoTable(tblInfo), autoIDAllocator: autoid.NewAllocatorFromTempTblInfo(tblInfo), + meta: tblInfo, } } @@ -1867,3 +1873,8 @@ func (t *TemporaryTable) GetSize() int64 { func (t *TemporaryTable) SetSize(v int64) { t.size = v } + +// GetMeta gets the table meta. +func (t *TemporaryTable) GetMeta() *model.TableInfo { + return t.meta +} diff --git a/util/tableutil/tableutil.go b/util/tableutil/tableutil.go index bf5d7caac2732..446cd170333aa 100644 --- a/util/tableutil/tableutil.go +++ b/util/tableutil/tableutil.go @@ -36,6 +36,8 @@ type TempTable interface { GetSize() int64 SetSize(int64) + + GetMeta() *model.TableInfo } // TempTableFromMeta builds a TempTable from *model.TableInfo. From e1292d51a728b89efbdd3702e83adbb7f424164b Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 8 Jul 2021 17:04:42 +0800 Subject: [PATCH 2/5] get membuf from txn --- go.mod | 2 -- session/session.go | 9 +++++++-- store/driver/txn/txn_driver.go | 2 +- store/driver/txn/unionstore_driver.go | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 6a1f0c7a8e77e..8a354b9b08e15 100644 --- a/go.mod +++ b/go.mod @@ -82,5 +82,3 @@ require ( ) go 1.16 - -replace github.com/tikv/client-go/v2 => /Users/wangchao/Code/pingcap/client-go diff --git a/session/session.go b/session/session.go index d3e9b15e673b9..9fbb8ef65497a 100644 --- a/session/session.go +++ b/session/session.go @@ -42,7 +42,6 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" - transaction "github.com/pingcap/tidb/store/driver/txn" "github.com/pingcap/tidb/util/topsql" "github.com/pingcap/tipb/go-binlog" "go.uber.org/zap" @@ -564,7 +563,13 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac } if sessionData == nil { - sessionData = transaction.NewMemBuffer(tikv.NewMemDB()) + // Create this txn just for getting a MemBuffer. It's a little tricky + bufferTxn, err := s.store.BeginWithOption(tikv.DefaultStartTSOption().SetStartTS(0)) + if err != nil { + return err + } + + sessionData = bufferTxn.GetMemBuffer() } if stage == kv.InvalidStagingHandle { diff --git a/store/driver/txn/txn_driver.go b/store/driver/txn/txn_driver.go index 98841cf64d597..3c0500e9934ff 100644 --- a/store/driver/txn/txn_driver.go +++ b/store/driver/txn/txn_driver.go @@ -118,7 +118,7 @@ func (txn *tikvTxn) Set(k kv.Key, v []byte) error { } func (txn *tikvTxn) GetMemBuffer() kv.MemBuffer { - return NewMemBuffer(txn.KVTxn.GetMemBuffer()) + return newMemBuffer(txn.KVTxn.GetMemBuffer()) } func (txn *tikvTxn) SetOption(opt int, val interface{}) { diff --git a/store/driver/txn/unionstore_driver.go b/store/driver/txn/unionstore_driver.go index aba35b7b69b35..d58eca5fd552d 100644 --- a/store/driver/txn/unionstore_driver.go +++ b/store/driver/txn/unionstore_driver.go @@ -27,7 +27,7 @@ type memBuffer struct { *tikv.MemDB } -func NewMemBuffer(m *tikv.MemDB) kv.MemBuffer { +func newMemBuffer(m *tikv.MemDB) kv.MemBuffer { if m == nil { return nil } From d2b4eac635e08d758be94074857be65fdfa0df97 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 8 Jul 2021 22:43:53 +0800 Subject: [PATCH 3/5] delete remove to avoid data race --- session/session.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/session/session.go b/session/session.go index 9fbb8ef65497a..9952522b8166d 100644 --- a/session/session.go +++ b/session/session.go @@ -602,11 +602,6 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac return err } - err = txnMemBuffer.Delete(key) - if err != nil { - return err - } - err = iter.Next() if err != nil { return err From b7a15c85d26024f15a62241ff52d1f36c1c3f54b Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Fri, 9 Jul 2021 17:41:06 +0800 Subject: [PATCH 4/5] address comment --- session/session.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/session/session.go b/session/session.go index 8093c2bf614e4..e701f3ad3693a 100644 --- a/session/session.go +++ b/session/session.go @@ -553,6 +553,14 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac sessionData := s.sessionVars.TemporaryTableData var stage kv.StagingHandle + + defer func() { + // stage != kv.InvalidStagingHandle means error occurs, we need to cleanup sessionData + if stage != kv.InvalidStagingHandle { + sessionData.Cleanup(stage) + } + }() + for tblID, tbl := range txnTempTables { if !tbl.GetModified() { continue @@ -610,16 +618,15 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac } err := txn.Commit(ctx) - if stage != kv.InvalidStagingHandle { - if err != nil { - sessionData.Cleanup(stage) - } else { - sessionData.Release(stage) - s.sessionVars.TemporaryTableData = sessionData - } + if err != nil { + return err } - return err + sessionData.Release(stage) + s.sessionVars.TemporaryTableData = sessionData + stage = kv.InvalidStagingHandle + + return nil } type temporaryTableKVFilter map[int64]tableutil.TempTable From 6e400217163eb9910333f772fb6239370ba77373 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Fri, 9 Jul 2021 17:51:12 +0800 Subject: [PATCH 5/5] fix panic --- session/session.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/session/session.go b/session/session.go index e701f3ad3693a..29c2fdc22be13 100644 --- a/session/session.go +++ b/session/session.go @@ -622,9 +622,11 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac return err } - sessionData.Release(stage) - s.sessionVars.TemporaryTableData = sessionData - stage = kv.InvalidStagingHandle + if stage != kv.InvalidStagingHandle { + sessionData.Release(stage) + s.sessionVars.TemporaryTableData = sessionData + stage = kv.InvalidStagingHandle + } return nil }