diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp
index d99e2c26d35..6a528b02f5f 100644
--- a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp
+++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp
@@ -59,12 +59,11 @@ void InterpreterDAG::initMPPExchangeReceiver(const DAGQueryBlock & dag_query_blo
     }
     if (dag_query_block.source->tp() == tipb::ExecType::TypeExchangeReceiver)
     {
-        /// use max_streams * 5 as the default receiver buffer size, maybe make it more configurable
         mpp_exchange_receiver_maps[dag_query_block.source_name] = std::make_shared<ExchangeReceiver>(
            context,
            dag_query_block.source->exchange_receiver(),
            dag.getDAGContext().getMPPTaskMeta(),
-            max_streams * 5);
+            max_streams);
     }
 }
diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp
index ba02448ab8b..70e52e89a26 100644
--- a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp
+++ b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp
@@ -42,6 +42,23 @@ void ExchangeReceiver::setUpConnection()
     }
 }
 
+String getReceiverStateStr(const State & s)
+{
+    switch (s)
+    {
+    case NORMAL:
+        return "NORMAL";
+    case ERROR:
+        return "ERROR";
+    case CANCELED:
+        return "CANCELED";
+    case CLOSED:
+        return "CLOSED";
+    default:
+        return "UNKNOWN";
+    }
+}
+
 void ExchangeReceiver::ReadLoop(const String & meta_raw, size_t source_index)
 {
     bool meet_error = false;
@@ -53,8 +70,11 @@ void ExchangeReceiver::ReadLoop(const String & meta_raw, size_t source_index)
     try
     {
         auto sender_task = new mpp::TaskMeta();
+        if (!sender_task->ParseFromString(meta_raw))
+        {
+            throw Exception("parse task meta error!");
+        }
         send_task_id = sender_task->task_id();
-        sender_task->ParseFromString(meta_raw);
         auto req = std::make_shared<mpp::EstablishMPPConnectionRequest>();
         req->set_allocated_receiver_meta(new mpp::TaskMeta(task_meta));
         req->set_allocated_sender_meta(sender_task);
@@ -66,27 +86,54 @@ void ExchangeReceiver::ReadLoop(const String & meta_raw, size_t source_index)
         grpc::ClientContext client_context;
         auto reader = cluster->rpc_client->sendStreamRequest(req->sender_meta().address(), &client_context, call);
         reader->WaitForInitialMetadata();
-        mpp::MPPDataPacket packet;
+        std::shared_ptr<ReceivedPacket> packet;
         String req_info = "tunnel" + std::to_string(send_task_id) + "+" + std::to_string(recv_task_id);
         bool has_data = false;
         for (;;)
         {
             LOG_TRACE(log, "begin next ");
-            bool success = reader->Read(&packet);
+            {
+                std::unique_lock<std::mutex> lock(mu);
+                cv.wait(lock, [&] { return res_buffer.hasEmpty() || state != NORMAL; });
+                if (state == NORMAL)
+                {
+                    res_buffer.popEmpty(packet);
+                    cv.notify_all();
+                }
+                else
+                {
+                    meet_error = true;
+                    local_err_msg = "receiver's state is " + getReceiverStateStr(state) + ", exit from ReadLoop";
+                    LOG_WARNING(log, local_err_msg);
+                    break;
+                }
+            }
+            packet->req_info = req_info;
+            packet->source_index = source_index;
+            bool success = reader->Read(packet->packet.get());
             if (!success)
                 break;
             else
                 has_data = true;
-            if (packet.has_error())
+            if (packet->packet->has_error())
             {
-                throw Exception("Exchange receiver meet error : " + packet.error().msg());
+                throw Exception("Exchange receiver meet error : " + packet->packet->error().msg());
             }
-            if (!decodePacket(packet, source_index, req_info))
             {
-                meet_error = true;
-                local_err_msg = "Decode packet meet error";
-                LOG_WARNING(log, "Decode packet meet error, exit from ReadLoop");
-                break;
+                std::unique_lock<std::mutex> lock(mu);
+                cv.wait(lock, [&] { return res_buffer.canPush() || state != NORMAL; });
+                if (state == NORMAL)
+                {
+                    res_buffer.pushObject(packet);
+                    cv.notify_all();
+                }
+                else
+                {
+                    meet_error = true;
+                    local_err_msg = "receiver's state is " + getReceiverStateStr(state) + ", exit from ReadLoop";
+                    LOG_WARNING(log, local_err_msg);
+                    break;
+                }
             }
         }
         // if meet error, such as decode packet fails, it will not retry.
@@ -133,22 +180,79 @@ void ExchangeReceiver::ReadLoop(const String & meta_raw, size_t source_index)
         meet_error = true;
         local_err_msg = "fatal error";
     }
-    std::lock_guard<std::mutex> lock(mu);
-    live_connections--;
-
-    // avoid concurrent conflict
-    Int32 live_conn_copy = live_connections;
+    Int32 copy_live_conn = -1;
+    {
+        std::unique_lock<std::mutex> lock(mu);
+        live_connections--;
+        if (meet_error && state == NORMAL)
+            state = ERROR;
+        if (meet_error && err_msg.empty())
+            err_msg = local_err_msg;
+        copy_live_conn = live_connections;
+        cv.notify_all();
+    }
+    LOG_DEBUG(log, fmt::format("{} -> {} end! current alive connections: {}", send_task_id, recv_task_id, copy_live_conn));
 
-    if (meet_error && state == NORMAL)
-        state = ERROR;
-    if (meet_error && err_msg.empty())
-        err_msg = local_err_msg;
-    cv.notify_all();
+    if (copy_live_conn == 0)
+        LOG_DEBUG(log, fmt::format("All threads end in ExchangeReceiver"));
+    else if (copy_live_conn < 0)
+        throw Exception("live_connections should not be less than 0!");
+}
 
-    LOG_DEBUG(log, fmt::format("{} -> {} end! current alive connections: {}", send_task_id, recv_task_id, live_conn_copy));
+ExchangeReceiverResult ExchangeReceiver::nextResult()
+{
+    std::shared_ptr<ReceivedPacket> packet;
+    {
+        std::unique_lock<std::mutex> lock(mu);
+        cv.wait(lock, [&] { return res_buffer.hasObjects() || live_connections == 0 || state != NORMAL; });
 
-    if (live_conn_copy == 0)
-        LOG_DEBUG(log, fmt::format("All threads end in ExchangeReceiver"));
+        if (state != NORMAL)
+        {
+            String msg;
+            if (state == CANCELED)
+                msg = "query canceled";
+            else if (state == CLOSED)
+                msg = "ExchangeReceiver closed";
+            else if (!err_msg.empty())
+                msg = err_msg;
+            else
+                msg = "Unknown error";
+            return {nullptr, 0, "ExchangeReceiver", true, msg, false};
+        }
+        else if (res_buffer.hasObjects())
+        {
+            res_buffer.popObject(packet);
+            cv.notify_all();
+        }
+        else /// live_connections == 0, res_buffer is empty, and state is NORMAL, that is the end.
+        {
+            return {nullptr, 0, "ExchangeReceiver", false, "", true};
+        }
+    }
+    assert(packet != nullptr && packet->packet != nullptr);
+    ExchangeReceiverResult result;
+    if (packet->packet->has_error())
+    {
+        result = {nullptr, packet->source_index, packet->req_info, true, packet->packet->error().msg(), false};
+    }
+    else
+    {
+        auto resp_ptr = std::make_shared<tipb::SelectResponse>();
+        if (!resp_ptr->ParseFromString(packet->packet->data()))
+        {
+            result = {nullptr, packet->source_index, packet->req_info, true, "decode error", false};
+        }
+        else
+        {
+            result = {resp_ptr, packet->source_index, packet->req_info};
+        }
+    }
+    packet->packet->Clear();
+    std::unique_lock<std::mutex> lock(mu);
+    cv.wait(lock, [&] { return res_buffer.canPushEmpty(); });
+    res_buffer.pushEmpty(std::move(packet));
+    cv.notify_all();
+    return result;
+}
 
 } // namespace DB
diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.h b/dbms/src/Flash/Mpp/ExchangeReceiver.h
index ff86afcf331..64f1a32fc4c 100644
--- a/dbms/src/Flash/Mpp/ExchangeReceiver.h
+++ b/dbms/src/Flash/Mpp/ExchangeReceiver.h
@@ -53,6 +53,99 @@ enum State
     CLOSED,
 };
 
+struct ReceivedPacket
+{
+    ReceivedPacket()
+    {
+        packet = std::make_shared<mpp::MPPDataPacket>();
+    }
+    std::shared_ptr<mpp::MPPDataPacket> packet;
+    size_t source_index = 0;
+    String req_info;
+};
+
+/// RecyclableBuffer recycles unused objects to avoid too much allocation of objects.
+template <typename T>
+class RecyclableBuffer
+{
+public:
+    explicit RecyclableBuffer(size_t limit)
+        : capacity(limit)
+    {
+        /// init empty objects
+        for (size_t i = 0; i < limit; ++i)
+        {
+            empty_objects.push(std::make_shared<T>());
+        }
+    }
+    bool hasEmpty() const
+    {
+        assert(!isOverflow(empty_objects));
+        return !empty_objects.empty();
+    }
+    bool hasObjects() const
+    {
+        assert(!isOverflow(objects));
+        return !objects.empty();
+    }
+    bool canPushEmpty() const
+    {
+        assert(!isOverflow(empty_objects));
+        return !isFull(empty_objects);
+    }
+    bool canPush() const
+    {
+        assert(!isOverflow(objects));
+        return !isFull(objects);
+    }
+
+    void popEmpty(std::shared_ptr<T> & t)
+    {
+        assert(!empty_objects.empty() && !isOverflow(empty_objects));
+        t = empty_objects.front();
+        empty_objects.pop();
+    }
+    void popObject(std::shared_ptr<T> & t)
+    {
+        assert(!objects.empty() && !isOverflow(objects));
+        t = objects.front();
+        objects.pop();
+    }
+    void pushObject(const std::shared_ptr<T> & t)
+    {
+        assert(!isFullOrOverflow(objects));
+        objects.push(t);
+    }
+    void pushEmpty(const std::shared_ptr<T> & t)
+    {
+        assert(!isFullOrOverflow(empty_objects));
+        empty_objects.push(t);
+    }
+    void pushEmpty(std::shared_ptr<T> && t)
+    {
+        assert(!isFullOrOverflow(empty_objects));
+        empty_objects.push(std::move(t));
+    }
+
+private:
+    bool isFullOrOverflow(const std::queue<std::shared_ptr<T>> & q) const
+    {
+        return q.size() >= capacity;
+    }
+    bool isOverflow(const std::queue<std::shared_ptr<T>> & q) const
+    {
+        return q.size() > capacity;
+    }
+    bool isFull(const std::queue<std::shared_ptr<T>> & q) const
+    {
+        return q.size() == capacity;
+    }
+
+    std::queue<std::shared_ptr<T>> empty_objects;
+    std::queue<std::shared_ptr<T>> objects;
+    size_t capacity;
+};
+
 class ExchangeReceiver
 {
 public:
@@ -64,14 +157,13 @@ class ExchangeReceiver
     tipb::ExchangeReceiver pb_exchange_receiver;
     size_t source_num;
     ::mpp::TaskMeta task_meta;
+    size_t max_streams;
     size_t max_buffer_size;
     std::vector<std::thread> workers;
     DAGSchema schema;
-
-    // TODO: should be a concurrency bounded queue.
     std::mutex mu;
     std::condition_variable cv;
-    std::queue<ExchangeReceiverResult> result_buffer;
+    RecyclableBuffer<ReceivedPacket> res_buffer;
     Int32 live_connections;
     State state;
     String err_msg;
@@ -82,39 +174,15 @@ class ExchangeReceiver
 
     void ReadLoop(const String & meta_raw, size_t source_index);
 
-    bool decodePacket(const mpp::MPPDataPacket & p, size_t source_index, const String & req_info)
-    {
-        bool ret = true;
-        std::shared_ptr<tipb::SelectResponse> resp_ptr = std::make_shared<tipb::SelectResponse>();
-        if (!resp_ptr->ParseFromString(p.data()))
-        {
-            resp_ptr = nullptr;
-            ret = false;
-        }
-        std::unique_lock<std::mutex> lock(mu);
-        cv.wait(lock, [&] { return result_buffer.size() < max_buffer_size || state != NORMAL; });
-        if (state == NORMAL)
-        {
-            if (resp_ptr != nullptr)
-                result_buffer.emplace(resp_ptr, source_index, req_info);
-            else
-                result_buffer.emplace(resp_ptr, source_index, req_info, true, "Error while decoding MPPDataPacket");
-        }
-        else
-        {
-            ret = false;
-        }
-        cv.notify_all();
-        return ret;
-    }
-
 public:
-    ExchangeReceiver(Context & context_, const ::tipb::ExchangeReceiver & exc, const ::mpp::TaskMeta & meta, size_t max_buffer_size_, const std::shared_ptr<LogWithPrefix> & log_ = nullptr)
+    ExchangeReceiver(Context & context_, const ::tipb::ExchangeReceiver & exc, const ::mpp::TaskMeta & meta, size_t max_streams_, const std::shared_ptr<LogWithPrefix> & log_ = nullptr)
         : cluster(context_.getTMTContext().getKVCluster())
         , pb_exchange_receiver(exc)
         , source_num(pb_exchange_receiver.encoded_task_meta_size())
         , task_meta(meta)
-        , max_buffer_size(max_buffer_size_)
+        , max_streams(max_streams_)
+        , max_buffer_size(max_streams_ * 2)
+        , res_buffer(max_buffer_size)
         , live_connections(pb_exchange_receiver.encoded_task_meta_size())
         , state(NORMAL)
     {
@@ -127,6 +195,7 @@ class ExchangeReceiver
             schema.push_back(std::make_pair(name, info));
         }
+
         setUpConnection();
     }
@@ -137,6 +206,7 @@ class ExchangeReceiver
             state = CLOSED;
             cv.notify_all();
         }
+
         for (auto & worker : workers)
         {
             worker.join();
@@ -152,36 +222,7 @@ class ExchangeReceiver
 
     const DAGSchema & getOutputSchema() const { return schema; }
 
-    ExchangeReceiverResult nextResult()
-    {
-        std::unique_lock<std::mutex> lk(mu);
-        cv.wait(lk, [&] { return !result_buffer.empty() || live_connections == 0 || state != NORMAL; });
-        ExchangeReceiverResult result;
-        if (state != NORMAL)
-        {
-            String msg;
-            if (state == CANCELED)
-                msg = "Query canceled";
-            else if (state == CLOSED)
-                msg = "ExchangeReceiver closed";
-            else if (!err_msg.empty())
-                msg = err_msg;
-            else
-                msg = "Unknown error";
-            result = {nullptr, 0, "ExchangeReceiver", true, msg, false};
-        }
-        else if (result_buffer.empty())
-        {
-            result = {nullptr, 0, "ExchangeReceiver", false, "", true};
-        }
-        else
-        {
-            result = result_buffer.front();
-            result_buffer.pop();
-        }
-        cv.notify_all();
-        return result;
-    }
+    ExchangeReceiverResult nextResult();
 
     size_t getSourceNum() { return source_num; }
     String getName() { return "ExchangeReceiver"; }
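
Note: below is a minimal, standalone sketch (not part of this patch) of the two-queue recycling scheme that RecyclableBuffer implements: a fixed pool of preallocated objects cycles between an "empty" queue filled by the consumer and a "filled" queue drained by it, so capacity bounds memory and no per-packet allocation occurs. The Pool class name, its methods, and the main() driver are illustrative assumptions, not TiFlash APIs; the real ExchangeReceiver additionally checks state and live_connections under the same mutex, as in ReadLoop/nextResult above.

#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

// Illustrative stand-in for RecyclableBuffer; names are assumptions.
template <typename T>
class Pool
{
public:
    explicit Pool(size_t limit)
    {
        // Preallocate the whole pool up front, like RecyclableBuffer's ctor.
        for (size_t i = 0; i < limit; ++i)
            empty_objects.push(std::make_shared<T>());
    }
    // Producer side: take a recycled object to fill (blocks when none left).
    std::shared_ptr<T> popEmpty()
    {
        std::unique_lock<std::mutex> lock(mu);
        cv.wait(lock, [&] { return !empty_objects.empty(); });
        auto t = empty_objects.front();
        empty_objects.pop();
        return t;
    }
    // Producer side: hand a filled object to the consumer.
    void pushObject(std::shared_ptr<T> t)
    {
        std::unique_lock<std::mutex> lock(mu);
        objects.push(std::move(t));
        cv.notify_all();
    }
    // Consumer side: take a filled object (blocks when none ready).
    std::shared_ptr<T> popObject()
    {
        std::unique_lock<std::mutex> lock(mu);
        cv.wait(lock, [&] { return !objects.empty(); });
        auto t = objects.front();
        objects.pop();
        return t;
    }
    // Consumer side: return a consumed object to the empty pool.
    void pushEmpty(std::shared_ptr<T> t)
    {
        std::unique_lock<std::mutex> lock(mu);
        empty_objects.push(std::move(t));
        cv.notify_all();
    }

private:
    std::mutex mu;
    std::condition_variable cv;
    std::queue<std::shared_ptr<T>> empty_objects; // recycled, ready to fill
    std::queue<std::shared_ptr<T>> objects;       // filled, ready to consume
};

int main()
{
    Pool<std::string> pool(2); // capacity bounds memory, like max_streams * 2
    std::thread producer([&] {
        for (int i = 0; i < 5; ++i)
        {
            auto buf = pool.popEmpty(); // blocks until consumer recycles
            *buf = "packet " + std::to_string(i);
            pool.pushObject(std::move(buf));
        }
    });
    for (int i = 0; i < 5; ++i)
    {
        auto buf = pool.popObject();
        std::cout << *buf << '\n';
        buf->clear();                   // like packet->packet->Clear()
        pool.pushEmpty(std::move(buf)); // recycle instead of freeing
    }
    producer.join();
    return 0;
}

The point of the design, as in the patch, is that the producer (ReadLoop) can never hold more than `capacity` packets in flight, and the consumer (nextResult) returns each buffer after use, so steady-state allocation is zero.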