Skip to content

Commit

Permalink
server: avoid reusing cached stmt ctx on cursor read (#40023)
Browse files Browse the repository at this point in the history
close #39998
  • Loading branch information
zyguan authored and ti-chi-bot committed Mar 2, 2023
1 parent c78e413 commit 9596975
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
2 changes: 1 addition & 1 deletion executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1937,7 +1937,7 @@ func (e *UnionExec) Close() error {
func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
vars := ctx.GetSessionVars()
var sc *stmtctx.StatementContext
if vars.TxnCtx.CouldRetry {
if vars.TxnCtx.CouldRetry || mysql.HasCursorExistsFlag(vars.Status) {
// Must construct new statement context object, the retry history need context for every statement.
// TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted.
sc = &stmtctx.StatementContext{}
Expand Down
17 changes: 13 additions & 4 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeScrollable", nil)
}

if !useCursor {
if useCursor {
cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true)
defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false)
} else {
// not using streaming ,can reuse chunk
cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc)
}
Expand Down Expand Up @@ -251,7 +254,8 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}
// The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried.
// Currently the first return value is used to fallback to TiKV when TiFlash is down.
func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) {
prepStmt, err := (&cc.ctx).GetSessionVars().GetPreparedStmtByID(uint32(stmt.ID()))
vars := (&cc.ctx).GetSessionVars()
prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID()))
if err != nil {
return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
}
Expand All @@ -274,6 +278,9 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
}
if rs == nil {
if useCursor {
vars.SetStatusFlag(mysql.ServerStatusCursorExists, false)
}
return false, cc.writeOK(ctx)
}
// since there are multiple implementations of ResultSet (the rs might be wrapped), we have to unwrap the rs before
Expand Down Expand Up @@ -304,7 +311,7 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
cl.OnFetchReturned()
}
// explicitly flush columnInfo to client.
err = cc.writeEOF(ctx, cc.ctx.Status()|mysql.ServerStatusCursorExists)
err = cc.writeEOF(ctx, cc.ctx.Status())
if err != nil {
return false, err
}
Expand All @@ -326,6 +333,8 @@ const (
func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) {
cc.ctx.GetSessionVars().StartTime = time.Now()
cc.ctx.GetSessionVars().ClearAlloc(nil, false)
cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true)
defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false)

stmtID, fetchSize, err := parseStmtFetchCmd(data)
if err != nil {
Expand Down Expand Up @@ -354,7 +363,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err
strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID))
}

_, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status()|mysql.ServerStatusCursorExists, int(fetchSize))
_, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status(), int(fetchSize))
if err != nil {
return errors.Annotate(err, cc.preparedStmt2String(stmtID))
}
Expand Down
57 changes: 57 additions & 0 deletions server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package server

import (
"bytes"
"context"
"encoding/binary"
"testing"
Expand Down Expand Up @@ -340,3 +341,59 @@ func TestCursorReadHoldTS(t *testing.T) {
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))
require.Zero(t, srv.GetMinStartTS(0))
}

func TestCursorExistsFlag(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
srv := CreateMockServer(t, store)
srv.SetDomain(dom)
defer srv.Close()

appendUint32 := binary.LittleEndian.AppendUint32
ctx := context.Background()
c := CreateMockConn(t, srv).(*mockConn)
out := new(bytes.Buffer)
c.pkt.bufWriter.Reset(out)
c.capability |= mysql.ClientDeprecateEOF | mysql.ClientProtocol41
tk := testkit.NewTestKitWithSession(t, store, c.Context().Session)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key)")
tk.MustExec("insert into t values (1), (2), (3), (4), (5), (6), (7), (8)")
tk.MustQuery("select count(*) from t").Check(testkit.Rows("8"))

getLastStatus := func() uint16 {
raw := out.Bytes()
return binary.LittleEndian.Uint16(raw[len(raw)-4 : len(raw)-2])
}

stmt, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)

require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
)))
require.True(t, mysql.HasCursorExistsFlag(getLastStatus()))

// fetch first 5
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.True(t, mysql.HasCursorExistsFlag(getLastStatus()))

// COM_QUERY during fetch
require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...)))
require.False(t, mysql.HasCursorExistsFlag(getLastStatus()))

// fetch last 3
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.True(t, mysql.HasCursorExistsFlag(getLastStatus()))

// final fetch with no row retured
// (tidb doesn't unset cursor-exists flag in the previous response like mysql, one more fetch is needed)
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.False(t, mysql.HasCursorExistsFlag(getLastStatus()))
require.True(t, getLastStatus()&mysql.ServerStatusLastRowSend > 0)

// COM_QUERY after fetch
require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...)))
require.False(t, mysql.HasCursorExistsFlag(getLastStatus()))
}

0 comments on commit 9596975

Please sign in to comment.