Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move tunnel_map to MPPTunnelSet #5123

Merged
merged 4 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 14 additions & 31 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,34 +78,28 @@ MPPTask::~MPPTask()

void MPPTask::closeAllTunnels(const String & reason)
{
for (auto & it : tunnel_map)
{
it.second->close(reason);
}
if (likely(tunnel_set))
tunnel_set->close(reason);
}

void MPPTask::finishWrite()
{
for (const auto & it : tunnel_map)
{
it.second->writeDone();
}
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->finishWrite();
}

void MPPTask::run()
{
newThreadManager()->scheduleThenDetach(true, "MPPTask", [self = shared_from_this()] { self->runImpl(); });
}

void MPPTask::registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel)
void MPPTask::registerTunnel(const MPPTaskId & task_id, MPPTunnelPtr tunnel)
{
if (status == CANCELLED)
throw Exception("the tunnel " + tunnel->id() + " can not been registered, because the task is cancelled");

if (tunnel_map.find(id) != tunnel_map.end())
throw Exception("the tunnel " + tunnel->id() + " has been registered");

tunnel_map[id] = tunnel;
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->registerTunnel(task_id, tunnel);
}

std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConnectionRequest * request)
Expand All @@ -120,16 +114,17 @@ std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConn
}

MPPTaskId receiver_id{request->receiver_meta().start_ts(), request->receiver_meta().task_id()};
auto it = tunnel_map.find(receiver_id);
if (it == tunnel_map.end())
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
auto tunnel_ptr = tunnel_set->getTunnelById(receiver_id);
if (tunnel_ptr == nullptr)
{
auto err_msg = fmt::format(
"can't find tunnel ({} + {})",
request->sender_meta().task_id(),
request->receiver_meta().task_id());
return {nullptr, err_msg};
}
return {it->second, ""};
return {tunnel_ptr, ""};
}

void MPPTask::unregisterTask()
Expand Down Expand Up @@ -211,7 +206,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
}

// register tunnels
tunnel_set = std::make_shared<MPPTunnelSet>();
tunnel_set = std::make_shared<MPPTunnelSet>(log->identifier());
std::chrono::seconds timeout(task_request.timeout());

for (int i = 0; i < exchange_sender.encoded_task_meta_size(); i++)
Expand All @@ -225,7 +220,6 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
MPPTunnelPtr tunnel = std::make_shared<MPPTunnel>(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier());
LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id());
registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel);
tunnel_set->addTunnel(tunnel);
if (!dag_context->isRootMPPTask())
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task);
Expand Down Expand Up @@ -369,19 +363,8 @@ void MPPTask::runImpl()

void MPPTask::writeErrToAllTunnels(const String & e)
{
for (auto & it : tunnel_map)
{
try
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel);
it.second->write(getPacketWithError(e), true);
}
catch (...)
{
it.second->close("Failed to write error msg to tunnel");
tryLogCurrentException(log, "Failed to write error " + e + " to tunnel: " + it.second->id());
}
}
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->writeError(e);
}

void MPPTask::cancel(const String & reason)
Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

MPPTunnelSetPtr tunnel_set;

// which targeted task we should send data by which tunnel.
std::unordered_map<MPPTaskId, MPPTunnelPtr> tunnel_map;

MPPTaskManager * manager = nullptr;

const LoggerPtr log;
Expand Down
51 changes: 51 additions & 0 deletions dbms/src/Flash/Mpp/MPPTunnelSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
// limitations under the License.

#include <Common/Exception.h>
#include <Common/FailPoint.h>
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/Utils.h>
#include <fmt/core.h>

namespace DB
{
namespace FailPoints
{
extern const char exception_during_mpp_write_err_to_tunnel[];
} // namespace FailPoints
namespace
{
inline mpp::MPPDataPacket serializeToPacket(const tipb::SelectResponse & response)
Expand Down Expand Up @@ -108,6 +114,51 @@ void MPPTunnelSetBase<Tunnel>::write(mpp::MPPDataPacket & packet, int16_t partit
tunnels[partition_id]->write(packet);
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::writeError(const String & msg)
{
for (auto & tunnel : tunnels)
{
try
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel);
tunnel->write(getPacketWithError(msg), true);
}
catch (...)
{
tryLogCurrentException(log, "Failed to write error " + msg + " to tunnel: " + tunnel->id());
tunnel->close("Failed to write error msg to tunnel");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in original code, tunnel->close is the 1st statement, are they same?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the original implementation.

}
}
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::close(const String & reason)
{
for (auto & tunnel : tunnels)
tunnel->close(reason);
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::finishWrite()
{
for (auto & tunnel : tunnels)
{
tunnel->writeDone();
}
}

template <typename Tunnel>
typename MPPTunnelSetBase<Tunnel>::TunnelPtr MPPTunnelSetBase<Tunnel>::getTunnelById(const MPPTaskId & id)
{
auto it = id_to_index_map.find(id);
if (it == id_to_index_map.end())
{
return nullptr;
}
return tunnels[it->second];
}

/// Explicit template instantiations - to avoid code bloat in headers.
template class MPPTunnelSetBase<MPPTunnel>;

Expand Down
16 changes: 15 additions & 1 deletion dbms/src/Flash/Mpp/MPPTunnelSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <Flash/Mpp/MPPTaskId.h>
#include <Flash/Mpp/MPPTunnel.h>
#ifdef __clang__
#pragma clang diagnostic push
Expand All @@ -32,6 +33,9 @@ class MPPTunnelSetBase : private boost::noncopyable
{
public:
using TunnelPtr = std::shared_ptr<Tunnel>;
explicit MPPTunnelSetBase(const String & req_id)
: log(Logger::get("MPPTunnelSet", req_id))
{}

void clearExecutionSummaries(tipb::SelectResponse & response);

Expand All @@ -50,11 +54,19 @@ class MPPTunnelSetBase : private boost::noncopyable
// this is a partition writing.
void write(tipb::SelectResponse & response, int16_t partition_id);
void write(mpp::MPPDataPacket & packet, int16_t partition_id);
void writeError(const String & msg);
void close(const String & reason);
void finishWrite();
TunnelPtr getTunnelById(const MPPTaskId & id);

uint16_t getPartitionNum() const { return tunnels.size(); }

void addTunnel(const TunnelPtr & tunnel)
void registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about hide the implementation in .cpp file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

{
if (id_to_index_map.find(id) != id_to_index_map.end())
throw Exception("the tunnel " + tunnel->id() + " has been registered");

id_to_index_map[id] = tunnels.size();
tunnels.push_back(tunnel);
if (!tunnel->isLocal())
{
Expand All @@ -71,6 +83,8 @@ class MPPTunnelSetBase : private boost::noncopyable

private:
std::vector<TunnelPtr> tunnels;
std::unordered_map<MPPTaskId, size_t> id_to_index_map;
bestwoody marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::unordered_map<MPPTaskId, size_t> id_to_index_map;
std::unordered_map<MPPTaskId, size_t> target_id_to_index_map;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, but the pr was already merged :(

const LoggerPtr log;

int remote_tunnel_cnt = 0;
};
Expand Down