diff --git a/executor/executor.go b/executor/executor.go index 08d137e252b59..98585b905d026 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1937,7 +1937,7 @@ func (e *UnionExec) Close() error { func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars := ctx.GetSessionVars() var sc *stmtctx.StatementContext - if vars.TxnCtx.CouldRetry { + if vars.TxnCtx.CouldRetry || mysql.HasCursorExistsFlag(vars.Status) { // Must construct new statement context object, the retry history need context for every statement. // TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted. sc = &stmtctx.StatementContext{} diff --git a/server/conn_stmt.go b/server/conn_stmt.go index cf2f9f2aa6e86..19e77ce222d51 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -158,7 +158,10 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeScrollable", nil) } - if !useCursor { + if useCursor { + cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) + defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) + } else { // not using streaming ,can reuse chunk cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc) } @@ -251,7 +254,8 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{} // The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried. // Currently the first return value is used to fallback to TiKV when TiFlash is down. func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) { - prepStmt, err := (&cc.ctx).GetSessionVars().GetPreparedStmtByID(uint32(stmt.ID())) + vars := (&cc.ctx).GetSessionVars() + prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID())) if err != nil { return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) } @@ -274,6 +278,9 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) } if rs == nil { + if useCursor { + vars.SetStatusFlag(mysql.ServerStatusCursorExists, false) + } return false, cc.writeOK(ctx) } // since there are multiple implementations of ResultSet (the rs might be wrapped), we have to unwrap the rs before @@ -304,7 +311,7 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm cl.OnFetchReturned() } // explicitly flush columnInfo to client. - err = cc.writeEOF(ctx, cc.ctx.Status()|mysql.ServerStatusCursorExists) + err = cc.writeEOF(ctx, cc.ctx.Status()) if err != nil { return false, err } @@ -326,6 +333,8 @@ const ( func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { cc.ctx.GetSessionVars().StartTime = time.Now() cc.ctx.GetSessionVars().ClearAlloc(nil, false) + cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) + defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) stmtID, fetchSize, err := parseStmtFetchCmd(data) if err != nil { @@ -354,7 +363,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID)) } - _, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status()|mysql.ServerStatusCursorExists, int(fetchSize)) + _, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status(), int(fetchSize)) if err != nil { return errors.Annotate(err, cc.preparedStmt2String(stmtID)) } diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index 2e60fc1085332..366e5c54ac222 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -15,6 +15,7 @@ package server import ( + "bytes" "context" "encoding/binary" "testing" @@ -340,3 +341,59 @@ func TestCursorReadHoldTS(t *testing.T) { require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0)) require.Zero(t, srv.GetMinStartTS(0)) } + +func TestCursorExistsFlag(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + out := new(bytes.Buffer) + c.pkt.bufWriter.Reset(out) + c.capability |= mysql.ClientDeprecateEOF | mysql.ClientProtocol41 + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key)") + tk.MustExec("insert into t values (1), (2), (3), (4), (5), (6), (7), (8)") + tk.MustQuery("select count(*) from t").Check(testkit.Rows("8")) + + getLastStatus := func() uint16 { + raw := out.Bytes() + return binary.LittleEndian.Uint16(raw[len(raw)-4 : len(raw)-2]) + } + + stmt, _, _, err := c.Context().Prepare("select * from t") + require.NoError(t, err) + + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // fetch first 5 + require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // COM_QUERY during fetch + require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...))) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // fetch last 3 + require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // final fetch with no row retured + // (tidb doesn't unset cursor-exists flag in the previous response like mysql, one more fetch is needed) + require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) + require.True(t, getLastStatus()&mysql.ServerStatusLastRowSend > 0) + + // COM_QUERY after fetch + require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...))) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) +}