diff --git a/executor/point_get.go b/executor/point_get.go index 3b6e05f589315..83678d29330c0 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -390,6 +390,12 @@ func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) if e.tblInfo.TempTableType == model.TempTableGlobal { return nil, nil } + + // Local temporary table always get snapshot value from session + if e.tblInfo.TempTableType == model.TempTableLocal { + return e.ctx.GetSessionVars().GetTemporaryTableSnapshotValue(ctx, key) + } + lock := e.tblInfo.Lock if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) { if e.ctx.GetSessionVars().EnablePointGetCache { diff --git a/session/session.go b/session/session.go index 9b7bca255740d..29c2fdc22be13 100644 --- a/session/session.go +++ b/session/session.go @@ -18,6 +18,7 @@ package session import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -541,7 +542,93 @@ func (s *session) doCommit(ctx context.Context) error { s.txn.SetOption(kv.KVFilter, temporaryTableKVFilter(tables)) } - return s.txn.Commit(tikvutil.SetSessionID(ctx, sessVars.ConnectionID)) + return s.commitTxnWithTemporaryData(tikvutil.SetSessionID(ctx, sessVars.ConnectionID), &s.txn) +} + +func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transaction) error { + txnTempTables := s.sessionVars.TxnCtx.TemporaryTables + if len(txnTempTables) == 0 { + return txn.Commit(ctx) + } + + sessionData := s.sessionVars.TemporaryTableData + var stage kv.StagingHandle + + defer func() { + // stage != kv.InvalidStagingHandle means error occurs, we need to cleanup sessionData + if stage != kv.InvalidStagingHandle { + sessionData.Cleanup(stage) + } + }() + + for tblID, tbl := range txnTempTables { + if !tbl.GetModified() { + continue + } + + if tbl.GetMeta().TempTableType != model.TempTableLocal { + continue + } + + if sessionData == nil { + // Create this txn just for getting a MemBuffer. It's a little tricky + bufferTxn, err := s.store.BeginWithOption(tikv.DefaultStartTSOption().SetStartTS(0)) + if err != nil { + return err + } + + sessionData = bufferTxn.GetMemBuffer() + } + + if stage == kv.InvalidStagingHandle { + stage = sessionData.Staging() + } + + tblPrefix := tablecodec.EncodeTablePrefix(tblID) + endKey := tablecodec.EncodeTablePrefix(tblID + 1) + + txnMemBuffer := s.txn.GetMemBuffer() + iter, err := txnMemBuffer.Iter(tblPrefix, endKey) + if err != nil { + return err + } + + for iter.Valid() { + key := iter.Key() + if !bytes.HasPrefix(key, tblPrefix) { + break + } + + value := iter.Value() + if len(value) == 0 { + err = sessionData.Delete(key) + } else { + err = sessionData.Set(key, iter.Value()) + } + + if err != nil { + return err + } + + err = iter.Next() + if err != nil { + return err + } + } + } + + err := txn.Commit(ctx) + if err != nil { + return err + } + + if stage != kv.InvalidStagingHandle { + sessionData.Release(stage) + s.sessionVars.TemporaryTableData = sessionData + stage = kv.InvalidStagingHandle + } + + return nil } type temporaryTableKVFilter map[int64]tableutil.TempTable diff --git a/session/session_test.go b/session/session_test.go index 67f68f27764ad..37861878fc157 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4859,3 +4859,98 @@ func (s *testSessionSuite) TestAuthPluginForUser(c *C) { c.Assert(err, IsNil) c.Assert(plugin, Equals, "") } + +func (s *testSessionSuite) TestLocalTemporaryTableInsert(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 (u, v) values(11, 101)") + tk.MustExec("insert into tmp1 (u, v) values(12, 102)") + tk.MustExec("insert into tmp1 values(3, 13, 102)") + + checkRecordOneTwoThreeAndNonExist := func() { + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102")) + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 102")) + tk.MustQuery("select * from tmp1 where id=99").Check(testkit.Rows()) + } + + // inserted records exist + checkRecordOneTwoThreeAndNonExist() + + // insert dup records out txn must be error + _, err := tk.Exec("insert into tmp1 values(1, 999, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + _, err = tk.Exec("insert into tmp1 values(99, 11, 999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + // insert dup records in txn must be error + tk.MustExec("begin") + _, err = tk.Exec("insert into tmp1 values(1, 999, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + _, err = tk.Exec("insert into tmp1 values(99, 11, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + checkRecordOneTwoThreeAndNonExist() + + tk.MustExec("insert into tmp1 values(4, 14, 104)") + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104")) + + _, err = tk.Exec("insert into tmp1 values(4, 999, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + + _, err = tk.Exec("insert into tmp1 values(99, 14, 9999)") + c.Assert(kv.ErrKeyExists.Equal(err), IsTrue) + + checkRecordOneTwoThreeAndNonExist() + tk.MustExec("commit") + + // check committed insert works + checkRecordOneTwoThreeAndNonExist() + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104")) + + // check rollback works + tk.MustExec("begin") + tk.MustExec("insert into tmp1 values(5, 15, 105)") + tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows("5 15 105")) + tk.MustExec("rollback") + tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows()) +} + +func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 values(1, 11, 101)") + tk.MustExec("insert into tmp1 values(2, 12, 102)") + + // check point get out transaction + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where u=11").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102")) + tk.MustQuery("select * from tmp1 where u=12").Check(testkit.Rows("2 12 102")) + + // check point get in transaction + tk.MustExec("begin") + tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where u=11").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102")) + tk.MustQuery("select * from tmp1 where u=12").Check(testkit.Rows("2 12 102")) + tk.MustExec("insert into tmp1 values(3, 13, 103)") + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103")) + tk.MustExec("update tmp1 set v=999 where id=2") + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999")) + tk.MustExec("commit") + + // check point get after transaction + tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) + tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103")) + tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999")) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 7166f6ec44bb6..8a0a99f42c1e6 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -15,6 +15,7 @@ package variable import ( "bytes" + "context" "crypto/tls" "encoding/binary" "fmt" @@ -859,6 +860,9 @@ type SessionVars struct { // LocalTemporaryTables is *infoschema.LocalTemporaryTables, use interface to avoid circle dependency. // It's nil if there is no local temporary table. LocalTemporaryTables interface{} + + // TemporaryTableData stores committed kv values for temporary table for current session. + TemporaryTableData kv.MemBuffer } // AllocMPPTaskID allocates task id for mpp tasks. It will reset the task id if the query's @@ -2199,3 +2203,40 @@ func (s *SessionVars) GetSeekFactor(tbl *model.TableInfo) float64 { } return s.seekFactor } + +// GetTemporaryTableSnapshotValue get temporary table value from session +func (s *SessionVars) GetTemporaryTableSnapshotValue(ctx context.Context, key kv.Key) ([]byte, error) { + memData := s.TemporaryTableData + if memData == nil { + return nil, kv.ErrNotExist + } + + v, err := memData.Get(ctx, key) + if err != nil { + return v, err + } + + if len(v) == 0 { + return nil, kv.ErrNotExist + } + + return v, nil +} + +// GetTemporaryTableTxnValue returns a kv.Getter to fetch temporary table data in txn +func (s *SessionVars) GetTemporaryTableTxnValue(ctx context.Context, txn kv.Transaction, key kv.Key) ([]byte, error) { + v, err := txn.GetMemBuffer().Get(ctx, key) + if err == nil { + if len(v) == 0 { + return nil, kv.ErrNotExist + } + + return v, nil + } + + if !kv.IsErrNotFound(err) { + return v, err + } + + return s.GetTemporaryTableSnapshotValue(ctx, key) +} diff --git a/table/tables/index.go b/table/tables/index.go index aef03d0590aaa..ae0eca1339482 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -199,7 +199,10 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue } var value []byte - if sctx.GetSessionVars().LazyCheckKeyNotExists() { + if c.tblInfo.TempTableType != model.TempTableNone { + // Always check key for temporary table because it does not write to TiKV + value, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key) + } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { value, err = txn.GetMemBuffer().Get(ctx, key) } else { value, err = txn.Get(ctx, key) diff --git a/table/tables/tables.go b/table/tables/tables.go index 5ef1707ad8010..bb89cf7f7a26f 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -770,7 +770,10 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . var setPresume bool skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck if (t.meta.IsCommonHandle || t.meta.PKIsHandle) && !skipCheck && !opt.SkipHandleCheck { - if sctx.GetSessionVars().LazyCheckKeyNotExists() { + if t.meta.TempTableType != model.TempTableNone { + // Always check key for temporary table because it does not write to TiKV + _, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key) + } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { var v []byte v, err = txn.GetMemBuffer().Get(ctx, key) if err != nil { @@ -1827,6 +1830,8 @@ type TemporaryTable struct { autoIDAllocator autoid.Allocator // Table size. size int64 + + meta *model.TableInfo } // TempTableFromMeta builds a TempTable from model.TableInfo. @@ -1835,6 +1840,7 @@ func TempTableFromMeta(tblInfo *model.TableInfo) tableutil.TempTable { modified: false, stats: statistics.PseudoTable(tblInfo), autoIDAllocator: autoid.NewAllocatorFromTempTblInfo(tblInfo), + meta: tblInfo, } } @@ -1867,3 +1873,8 @@ func (t *TemporaryTable) GetSize() int64 { func (t *TemporaryTable) SetSize(v int64) { t.size = v } + +// GetMeta gets the table meta. +func (t *TemporaryTable) GetMeta() *model.TableInfo { + return t.meta +} diff --git a/util/tableutil/tableutil.go b/util/tableutil/tableutil.go index bf5d7caac2732..446cd170333aa 100644 --- a/util/tableutil/tableutil.go +++ b/util/tableutil/tableutil.go @@ -36,6 +36,8 @@ type TempTable interface { GetSize() int64 SetSize(int64) + + GetMeta() *model.TableInfo } // TempTableFromMeta builds a TempTable from *model.TableInfo.