diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp index 88f29305dbe..9f75d8163fc 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp @@ -314,9 +314,10 @@ ExchangeReceiverBase::ExchangeReceiverBase( , enable_fine_grained_shuffle_flag(enableFineGrainedShuffle(fine_grained_shuffle_stream_count_)) , output_stream_count(enable_fine_grained_shuffle_flag ? std::min(max_streams_, fine_grained_shuffle_stream_count_) : max_streams_) , max_buffer_size(std::max(batch_packet_count, std::max(source_num, max_streams_) * 2)) + , connection_uncreated_num(source_num) , thread_manager(newThreadManager()) - , live_connections(0) , live_local_connections(0) + , live_connections(source_num) , state(ExchangeReceiverState::NORMAL) , exc_log(Logger::get(req_id, executor_id)) , collected(false) @@ -336,6 +337,7 @@ ExchangeReceiverBase::ExchangeReceiverBase( { try { + handleConnectionAfterException(); cancel(); thread_manager->wait(); } @@ -366,6 +368,16 @@ ExchangeReceiverBase::~ExchangeReceiverBase() } } +template +void ExchangeReceiverBase::handleConnectionAfterException() +{ + std::lock_guard lock(mu); + live_connections -= connection_uncreated_num; + + // some cv may have been blocked, wake them up and recheck the condition. + cv.notify_all(); +} + template void ExchangeReceiverBase::waitAllConnectionDone() { @@ -424,24 +436,9 @@ template void ExchangeReceiverBase::addLocalConnectionNum() { std::lock_guard lock(mu); - ++live_connections; ++live_local_connections; } -template -void ExchangeReceiverBase::addSyncConnectionNum() -{ - std::lock_guard lock(mu); - ++live_connections; -} - -template -void ExchangeReceiverBase::addAsyncConnectionNum(Int32 conn_num) -{ - std::lock_guard lock(mu); - live_connections += conn_num; -} - template void ExchangeReceiverBase::setUpConnection() { @@ -475,6 +472,7 @@ void ExchangeReceiverBase::setUpConnection() req.source_index, local_request_handler, enable_fine_grained_shuffle_flag); + --connection_uncreated_num; } else { @@ -486,12 +484,14 @@ void ExchangeReceiverBase::setUpConnection() }); ++thread_count; + --connection_uncreated_num; } } // TODO: reduce this thread in the future. if (!async_requests.empty()) { + auto async_conn_num = async_requests.size(); thread_manager->schedule(true, "RecvReactor", [this, async_requests = std::move(async_requests)] { if (enable_fine_grained_shuffle_flag) reactor(async_requests); @@ -500,6 +500,7 @@ void ExchangeReceiverBase::setUpConnection() }); ++thread_count; + connection_uncreated_num -= async_conn_num; } } @@ -517,7 +518,6 @@ void ExchangeReceiverBase::reactor(const std::vector & asyn CPUAffinityManager::getInstance().bindSelfQueryThread(); size_t alive_async_connections = async_requests.size(); - addAsyncConnectionNum(alive_async_connections); MPMCQueue ready_requests(alive_async_connections * 2); std::vector> handlers; @@ -550,8 +550,6 @@ template template void ExchangeReceiverBase::readLoop(const Request & req) { - addSyncConnectionNum(); - GET_METRIC(tiflash_thread_count, type_threads_of_receiver_read_loop).Increment(); SCOPE_EXIT({ GET_METRIC(tiflash_thread_count, type_threads_of_receiver_read_loop).Decrement(); @@ -818,7 +816,7 @@ void ExchangeReceiverBase::connectionDone( const String & local_err_msg, const LoggerPtr & log) { - Int32 copy_live_conn = -1; + Int32 copy_live_connections; { std::lock_guard lock(mu); @@ -829,7 +827,8 @@ void ExchangeReceiverBase::connectionDone( if (err_msg.empty()) err_msg = local_err_msg; } - copy_live_conn = --live_connections; + + copy_live_connections = --live_connections; } LOG_DEBUG( @@ -837,18 +836,21 @@ void ExchangeReceiverBase::connectionDone( "connection end. meet error: {}, err msg: {}, current alive connections: {}", meet_error, local_err_msg, - copy_live_conn); + copy_live_connections); - if (copy_live_conn == 0) + if (copy_live_connections == 0) { LOG_DEBUG(log, "All threads end in ExchangeReceiver"); cv.notify_all(); } - else if (copy_live_conn < 0) - throw Exception("live_connections should not be less than 0!"); + else if (copy_live_connections < 0) + throw Exception("alive_connection_num should not be less than 0!"); - if (meet_error || copy_live_conn == 0) + if (meet_error || copy_live_connections == 0) + { + LOG_INFO(exc_log, "receiver channels finished"); finishAllMsgChannels(); + } } template diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.h b/dbms/src/Flash/Mpp/ExchangeReceiver.h index 087b7c45c97..bcef3dfe427 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.h +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.h @@ -161,10 +161,8 @@ class ExchangeReceiverBase private: void prepareMsgChannels(); void addLocalConnectionNum(); - void addSyncConnectionNum(); - void addAsyncConnectionNum(Int32 conn_num); - void connectionLocalDone(); + void handleConnectionAfterException(); bool isReceiverForTiFlashStorage() { @@ -180,6 +178,7 @@ class ExchangeReceiverBase const bool enable_fine_grained_shuffle_flag; const size_t output_stream_count; const size_t max_buffer_size; + Int32 connection_uncreated_num; std::shared_ptr thread_manager; DAGSchema schema; @@ -189,8 +188,8 @@ class ExchangeReceiverBase std::mutex mu; std::condition_variable cv; /// should lock `mu` when visit these members - Int32 live_connections; Int32 live_local_connections; + Int32 live_connections; ExchangeReceiverState state; String err_msg;