diff --git a/domain/plan_replayer.go b/domain/plan_replayer.go index d237445f5404d..66f99b96d0fe6 100644 --- a/domain/plan_replayer.go +++ b/domain/plan_replayer.go @@ -299,6 +299,13 @@ func (r *planReplayerDumpTaskStatus) GetRunningTaskStatusLen() int { return len(r.runningTaskMu.runningTasks) } +// CleanFinishedTaskStatus clean then finished tasks, only used for unit test +func (r *planReplayerDumpTaskStatus) CleanFinishedTaskStatus() { + r.finishedTaskMu.Lock() + defer r.finishedTaskMu.Unlock() + r.finishedTaskMu.finishedTask = map[PlanReplayerTaskKey]struct{}{} +} + // GetFinishedTaskStatusLen used for unit test func (r *planReplayerDumpTaskStatus) GetFinishedTaskStatusLen() int { r.finishedTaskMu.RLock() diff --git a/domain/plan_replayer_handle_test.go b/domain/plan_replayer_handle_test.go index dccb400ecd5b6..8a2783af1274b 100644 --- a/domain/plan_replayer_handle_test.go +++ b/domain/plan_replayer_handle_test.go @@ -101,4 +101,23 @@ func TestPlanReplayerHandleDumpTask(t *testing.T) { err = prHandle.CollectPlanReplayerTask() require.NoError(t, err) require.Len(t, prHandle.GetTasks(), 0) + + // clean the task and register task + prHandle.GetTaskStatus().CleanFinishedTaskStatus() + tk.MustExec("delete from mysql.plan_replayer_task") + tk.MustExec("delete from mysql.plan_replayer_status") + tk.MustExec(fmt.Sprintf("insert into mysql.plan_replayer_task (sql_digest, plan_digest) values ('%v','%v');", sqlDigest, "*")) + err = prHandle.CollectPlanReplayerTask() + require.NoError(t, err) + require.Len(t, prHandle.GetTasks(), 1) + tk.MustQuery("select * from t;") + task = prHandle.DrainTask() + require.NotNil(t, task) + worker = prHandle.GetWorker() + success = worker.HandleTask(task) + require.True(t, success) + require.Equal(t, prHandle.GetTaskStatus().GetRunningTaskStatusLen(), 0) + require.Equal(t, prHandle.GetTaskStatus().GetFinishedTaskStatusLen(), 1) + // assert capture * task still remained + require.Len(t, prHandle.GetTasks(), 1) } diff --git a/executor/compiler.go b/executor/compiler.go index 5d16a4fbea6e7..ce8b487e24657 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -176,9 +176,11 @@ func checkPlanReplayerCaptureTask(sctx sessionctx.Context, stmtNode ast.StmtNode _, sqlDigest := sctx.GetSessionVars().StmtCtx.SQLDigest() _, planDigest := getPlanDigest(sctx.GetSessionVars().StmtCtx) for _, task := range tasks { - if task.SQLDigest == sqlDigest.String() && task.PlanDigest == planDigest.String() { - sendPlanReplayerDumpTask(sqlDigest.String(), planDigest.String(), sctx, stmtNode) - return + if task.SQLDigest == sqlDigest.String() { + if task.PlanDigest == "*" || task.PlanDigest == planDigest.String() { + sendPlanReplayerDumpTask(sqlDigest.String(), planDigest.String(), sctx, stmtNode) + return + } } } }