Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fleet_executor] interceptor send message through message_bus #37106

Merged
merged 3 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
Expand Down
5 changes: 0 additions & 5 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,5 @@ std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
return nullptr;
}

// Accessor for the message bus. This is a stub: nothing is looked up and
// nullptr is returned unconditionally, so callers must not dereference the
// result.
std::shared_ptr<MessageBus> FleetExecutor::GetMessageBus() {
// get message bus
return nullptr;
}

} // namespace distributed
} // namespace paddle
2 changes: 0 additions & 2 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ class FleetExecutor final {
void Run();
void Release();
static std::shared_ptr<Carrier> GetCarrier();
static std::shared_ptr<MessageBus> GetMessageBus();

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
static std::shared_ptr<Carrier> global_carrier_;
static std::shared_ptr<MessageBus> global_message_bus_;
};

} // namespace distributed
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

#include <utility>

#include "paddle/fluid/distributed/fleet_executor/message_bus.h"

namespace paddle {
namespace distributed {
Expand All @@ -27,9 +28,7 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)

// Waits for the interceptor's thread to finish before the object goes away.
// Calling join() on a thread that is not joinable throws std::system_error,
// and an exception escaping a destructor terminates the program, so the join
// is guarded. (Assumes interceptor_thread_ is a std::thread.)
Interceptor::~Interceptor() {
  if (interceptor_thread_.joinable()) {
    interceptor_thread_.join();
  }
}

// Stores the callback that Handle() consults when a message arrives.
// `handle` is taken by value, so it is moved into the member instead of being
// copied a second time (copying a std::function may allocate).
void Interceptor::RegisterInterceptorHandle(InterceptorHandle handle) {
  handle_ = std::move(handle);
}
// Stores the callback that Handle() consults when a message arrives.
// `handle` is taken by value, so it is moved into the member instead of being
// copied a second time (copying a std::function may allocate).
void Interceptor::RegisterMsgHandle(MsgHandle handle) {
  handle_ = std::move(handle);
}

void Interceptor::Handle(const InterceptorMessage& msg) {
if (handle_) {
Expand Down Expand Up @@ -61,7 +60,7 @@ void Interceptor::Send(int64_t dst_id,
std::unique_ptr<InterceptorMessage> msg) {
msg->set_src_id(interceptor_id_);
msg->set_dst_id(dst_id);
// send interceptor msg
MessageBus::Instance().Send(*msg.get());
}

void Interceptor::PoolTheMailbox() {
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TaskNode;

class Interceptor {
public:
using InterceptorHandle = std::function<void(const InterceptorMessage&)>;
using MsgHandle = std::function<void(const InterceptorMessage&)>;

public:
Interceptor() = delete;
Expand All @@ -44,7 +44,7 @@ class Interceptor {
virtual ~Interceptor();

// register interceptor handle
void RegisterInterceptorHandle(InterceptorHandle handle);
void RegisterMsgHandle(MsgHandle handle);

void Handle(const InterceptorMessage& msg);

Expand Down Expand Up @@ -77,7 +77,7 @@ class Interceptor {
TaskNode* node_;

// handle (callback) that processes incoming messages
InterceptorHandle handle_{nullptr};
MsgHandle handle_{nullptr};

// mutex to control read/write conflict for remote mailbox
std::mutex remote_mailbox_mutex_;
Expand Down
20 changes: 14 additions & 6 deletions paddle/fluid/distributed/fleet_executor/message_bus.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,28 @@
namespace paddle {
namespace distributed {

MessageBus::MessageBus(
void MessageBus::Init(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr)
: interceptor_id_to_rank_(interceptor_id_to_rank),
rank_to_addr_(rank_to_addr),
addr_(addr) {
const std::string& addr) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"MessageBus is already init."));
is_init_ = true;
interceptor_id_to_rank_ = interceptor_id_to_rank;
rank_to_addr_ = rank_to_addr;
addr_ = addr;

listen_port_thread_ = std::thread([this]() {
VLOG(3) << "Start listen_port_thread_ for message bus";
ListenPort();
});

std::call_once(once_flag_, []() {
std::atexit([]() { MessageBus::Instance().Release(); });
});
}

MessageBus::~MessageBus() {
void MessageBus::Release() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000);
Expand Down
20 changes: 15 additions & 5 deletions paddle/fluid/distributed/fleet_executor/message_bus.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
Expand All @@ -35,22 +36,28 @@ namespace distributed {

class Carrier;

// A singleton MessageBus
class MessageBus final {
public:
MessageBus() = delete;
// Returns the single shared MessageBus. The object is a function-local
// static, so it is constructed on the first call and the same reference is
// handed back on every later call.
static MessageBus& Instance() {
  static MessageBus instance;
  return instance;
}

MessageBus(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr);
void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr);

~MessageBus();
void Release();

// called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message);

DISABLE_COPY_AND_ASSIGN(MessageBus);

private:
MessageBus() = default;

// keeps listening on the port and handles incoming messages
void ListenPort();

Expand All @@ -66,6 +73,9 @@ class MessageBus final {
// send the message within a rank (dst is on the same rank as src)
bool SendIntraRank(const InterceptorMessage& interceptor_message);

bool is_init_{false};
std::once_flag once_flag_;

// provided by the layer above; maps each interceptor id to its rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;

Expand Down