Skip to content

Commit

Permalink
Refine MPPTask's status management (#2466)
Browse files Browse the repository at this point in the history
  • Loading branch information
fuzhe1989 authored Sep 9, 2021
1 parent 92e6681 commit d0825f0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 40 deletions.
8 changes: 2 additions & 6 deletions dbms/src/Flash/Mpp/MPPHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

namespace DB
{

namespace FailPoints
{
extern const char exception_before_mpp_non_root_task_run[];
Expand All @@ -17,11 +16,8 @@ void MPPHandler::handleError(const MPPTaskPtr & task, String error)
{
try
{
if (task != nullptr)
{
task->closeAllTunnel(error);
task->unregisterTask();
}
if (task)
task->cancel(error);
}
catch (...)
{
Expand Down
61 changes: 39 additions & 22 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ MPPTask::~MPPTask()
/// MPPTask may be destructed by a different thread, so set current_memory_tracker
/// to the query's memory_tracker in the destructor
current_memory_tracker = memory_tracker;
closeAllTunnel("");
closeAllTunnels("");
LOG_DEBUG(log, "finish MPPTask: " << id.toString());
}

void MPPTask::closeAllTunnel(const String & reason)
void MPPTask::closeAllTunnels(const String & reason)
{
for (auto & it : tunnel_map)
{
Expand Down Expand Up @@ -305,12 +305,12 @@ void MPPTask::preprocess()

void MPPTask::runImpl()
{
auto old_status = static_cast<Int32>(INITIALIZING);
if (!status.compare_exchange_strong(old_status, static_cast<Int32>(RUNNING)))
if (!switchStatus(INITIALIZING, RUNNING))
{
LOG_WARNING(log, "task not in initializing state, skip running");
return;
}

current_memory_tracker = memory_tracker;
Stopwatch stopwatch;
GET_METRIC(tiflash_coprocessor_request_count, type_run_mpp_task).Increment();
Expand Down Expand Up @@ -398,14 +398,18 @@ void MPPTask::runImpl()
}
else
{
writeErrToAllTunnel(err_msg);
writeErrToAllTunnels(err_msg);
}
LOG_INFO(log, "task ends, time cost is " << std::to_string(stopwatch.elapsedMilliseconds()) << " ms.");
unregisterTask();
status = FINISHED;

if (switchStatus(RUNNING, FINISHED))
LOG_INFO(log, "finish task");
else
LOG_WARNING(log, "finish task which was cancelled before");
}

void MPPTask::writeErrToAllTunnel(const String & e)
void MPPTask::writeErrToAllTunnels(const String & e)
{
for (auto & it : tunnel_map)
{
Expand All @@ -424,23 +428,36 @@ void MPPTask::writeErrToAllTunnel(const String & e)

void MPPTask::cancel(const String & reason)
{
auto current_status = status.exchange(CANCELLED);
if (current_status == FINISHED || current_status == CANCELLED)
LOG_WARNING(log, "Begin cancel task: " + id.toString());
while (true)
{
if (current_status == FINISHED)
status = FINISHED;
return;
auto previous_status = status.load();
if (previous_status == FINISHED || previous_status == CANCELLED)
{
LOG_WARNING(log, "task already " << (previous_status == FINISHED ? "finished" : "cancelled"));
return;
}
else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, CANCELLED))
{
closeAllTunnels(reason);
unregisterTask();
LOG_WARNING(log, "Finish cancel task from uninitialized");
return;
}
else if (previous_status == RUNNING && switchStatus(RUNNING, CANCELLED))
{
context.getProcessList().sendCancelToQuery(context.getCurrentQueryId(), context.getClientInfo().current_user, true);
closeAllTunnels(reason);
/// runImpl is running, leave remaining work to runImpl
LOG_WARNING(log, "Finish cancel task from running");
return;
}
}
LOG_WARNING(log, "Begin cancel task: " + id.toString());
/// step 1. cancel query streams if it is running
if (current_status == RUNNING)
context.getProcessList().sendCancelToQuery(context.getCurrentQueryId(), context.getClientInfo().current_user, true);
/// step 2. write Error msg and close the tunnel.
/// Here we use `closeAllTunnel` because currently, `cancel` is a query level cancel, which
/// means if this mpp task is cancelled, all the mpp tasks belonging to the same query are
/// cancelled at the same time, so there is no guarantee that the tunnel can be connected.
closeAllTunnel(reason);
LOG_WARNING(log, "Finish cancel task: " + id.toString());
}

/// Atomically moves `status` from `from` to `to`.
/// The swap happens only if `status` currently equals `from`; otherwise `status` is left untouched.
/// Returns true when the transition was performed, false when `status` held some other value.
bool MPPTask::switchStatus(TaskStatus from, TaskStatus to)
{
    /// compare_exchange_strong overwrites its first argument with the observed value on failure,
    /// so work on a local copy and report only whether the exchange succeeded.
    TaskStatus expected = from;
    const bool switched = status.compare_exchange_strong(expected, to);
    return switched;
}

} // namespace DB
26 changes: 14 additions & 12 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,10 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

bool isRootMPPTask() const { return dag_context->isRootMPPTask(); }

TaskStatus getStatus() const { return static_cast<TaskStatus>(status.load()); }

void unregisterTask();
TaskStatus getStatus() const { return status.load(); }

void cancel(const String & reason);

/// Similar to `writeErrToAllTunnel`, but it just tries to write the error message to the tunnel
/// without waiting for the tunnel to be connected
void closeAllTunnel(const String & reason);

void finishWrite();

void writeErrToAllTunnel(const String & e);

std::vector<RegionInfo> prepare(const mpp::DispatchTaskRequest & task_request);

void preprocess();
Expand All @@ -82,6 +72,18 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

void runImpl();

void unregisterTask();

void writeErrToAllTunnels(const String & e);

/// Similar to `writeErrToAllTunnels`, but it just tries to write the error message to the tunnel
/// without waiting for the tunnel to be connected
void closeAllTunnels(const String & reason);

void finishWrite();

bool switchStatus(TaskStatus from, TaskStatus to);

Context context;

RegionInfoMap local_regions;
Expand All @@ -97,7 +99,7 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

MPPTaskId id;

std::atomic<Int32> status{INITIALIZING};
std::atomic<TaskStatus> status{INITIALIZING};

mpp::TaskMeta meta;
MPPTunnelSetPtr tunnel_set;
Expand Down

0 comments on commit d0825f0

Please sign in to comment.