diff --git a/dbms/src/Flash/Mpp/MPPHandler.cpp b/dbms/src/Flash/Mpp/MPPHandler.cpp index ca4838942d4..f7d0f14c9ea 100644 --- a/dbms/src/Flash/Mpp/MPPHandler.cpp +++ b/dbms/src/Flash/Mpp/MPPHandler.cpp @@ -6,7 +6,6 @@ namespace DB { - namespace FailPoints { extern const char exception_before_mpp_non_root_task_run[]; @@ -17,11 +16,8 @@ void MPPHandler::handleError(const MPPTaskPtr & task, String error) { try { - if (task != nullptr) - { - task->closeAllTunnel(error); - task->unregisterTask(); - } + if (task) + task->cancel(error); } catch (...) { diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 8fbbce4b52a..b6de63095d8 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -56,11 +56,11 @@ MPPTask::~MPPTask() /// MPPTask maybe destructed by different thread, set the query memory_tracker /// to current_memory_tracker in the destructor current_memory_tracker = memory_tracker; - closeAllTunnel(""); + closeAllTunnels(""); LOG_DEBUG(log, "finish MPPTask: " << id.toString()); } -void MPPTask::closeAllTunnel(const String & reason) +void MPPTask::closeAllTunnels(const String & reason) { for (auto & it : tunnel_map) { @@ -305,12 +305,12 @@ void MPPTask::preprocess() void MPPTask::runImpl() { - auto old_status = static_cast(INITIALIZING); - if (!status.compare_exchange_strong(old_status, static_cast(RUNNING))) + if (!switchStatus(INITIALIZING, RUNNING)) { LOG_WARNING(log, "task not in initializing state, skip running"); return; } + current_memory_tracker = memory_tracker; Stopwatch stopwatch; GET_METRIC(tiflash_coprocessor_request_count, type_run_mpp_task).Increment(); @@ -398,14 +398,18 @@ void MPPTask::runImpl() } else { - writeErrToAllTunnel(err_msg); + writeErrToAllTunnels(err_msg); } LOG_INFO(log, "task ends, time cost is " << std::to_string(stopwatch.elapsedMilliseconds()) << " ms."); unregisterTask(); - status = FINISHED; + + if (switchStatus(RUNNING, FINISHED)) + LOG_INFO(log, "finish task"); + else + 
LOG_WARNING(log, "finish task which was cancelled before"); } -void MPPTask::writeErrToAllTunnel(const String & e) +void MPPTask::writeErrToAllTunnels(const String & e) { for (auto & it : tunnel_map) { @@ -424,23 +428,36 @@ void MPPTask::writeErrToAllTunnel(const String & e) void MPPTask::cancel(const String & reason) { - auto current_status = status.exchange(CANCELLED); - if (current_status == FINISHED || current_status == CANCELLED) + LOG_WARNING(log, "Begin cancel task: " + id.toString()); + while (true) { - if (current_status == FINISHED) - status = FINISHED; - return; + auto previous_status = status.load(); + if (previous_status == FINISHED || previous_status == CANCELLED) + { + LOG_WARNING(log, "task already " << (previous_status == FINISHED ? "finished" : "cancelled")); + return; + } + else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, CANCELLED)) + { + closeAllTunnels(reason); + unregisterTask(); + LOG_WARNING(log, "Finish cancel task from uninitialized"); + return; + } + else if (previous_status == RUNNING && switchStatus(RUNNING, CANCELLED)) + { + context.getProcessList().sendCancelToQuery(context.getCurrentQueryId(), context.getClientInfo().current_user, true); + closeAllTunnels(reason); + /// runImpl is running, leave remaining work to runImpl + LOG_WARNING(log, "Finish cancel task from running"); + return; + } } - LOG_WARNING(log, "Begin cancel task: " + id.toString()); - /// step 1. cancel query streams if it is running - if (current_status == RUNNING) - context.getProcessList().sendCancelToQuery(context.getCurrentQueryId(), context.getClientInfo().current_user, true); - /// step 2. write Error msg and close the tunnel. - /// Here we use `closeAllTunnel` because currently, `cancel` is a query level cancel, which - /// means if this mpp task is cancelled, all the mpp tasks belonging to the same query are - /// cancelled at the same time, so there is no guarantee that the tunnel can be connected. 
- closeAllTunnel(reason); - LOG_WARNING(log, "Finish cancel task: " + id.toString()); +} + +bool MPPTask::switchStatus(TaskStatus from, TaskStatus to) +{ + return status.compare_exchange_strong(from, to); } } // namespace DB diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index a6f99314d29..e7c56d6b114 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -48,20 +48,10 @@ class MPPTask : public std::enable_shared_from_this bool isRootMPPTask() const { return dag_context->isRootMPPTask(); } - TaskStatus getStatus() const { return static_cast(status.load()); } - - void unregisterTask(); + TaskStatus getStatus() const { return status.load(); } void cancel(const String & reason); - /// Similar to `writeErrToAllTunnel`, but it just try to write the error message to tunnel - /// without waiting the tunnel to be connected - void closeAllTunnel(const String & reason); - - void finishWrite(); - - void writeErrToAllTunnel(const String & e); - std::vector prepare(const mpp::DispatchTaskRequest & task_request); void preprocess(); @@ -82,6 +72,18 @@ class MPPTask : public std::enable_shared_from_this void runImpl(); + void unregisterTask(); + + void writeErrToAllTunnels(const String & e); + + /// Similar to `writeErrToAllTunnels`, but it just tries to write the error message to the tunnel + /// without waiting for the tunnel to be connected + void closeAllTunnels(const String & reason); + + void finishWrite(); + + bool switchStatus(TaskStatus from, TaskStatus to); + Context context; RegionInfoMap local_regions; @@ -97,7 +99,7 @@ class MPPTask : public std::enable_shared_from_this MPPTaskId id; - std::atomic status{INITIALIZING}; + std::atomic status{INITIALIZING}; mpp::TaskMeta meta; MPPTunnelSetPtr tunnel_set;