Skip to content

Commit

Permalink
[grpc] refactor rpc server to support multiple io services (#5023)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijunfu authored and pcmoritz committed Jun 26, 2019
1 parent aa5fc52 commit bb8e75b
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 50 deletions.
4 changes: 3 additions & 1 deletion src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(),
config.max_lineage_size),
actor_registry_(),
node_manager_server_(config.node_manager_port, io_service, *this),
node_manager_server_("NodeManager", config.node_manager_port),
node_manager_service_(io_service, *this),
client_call_manager_(io_service) {
RAY_CHECK(heartbeat_period_.count() > 0);
// Initialize the resource map with own cluster resource configuration.
Expand All @@ -119,6 +120,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,

RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
// Run the node manger rpc server.
node_manager_server_.RegisterService(node_manager_service_);
node_manager_server_.Run();
}

Expand Down
5 changes: 4 additions & 1 deletion src/ray/raylet/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
std::unordered_map<ActorID, ActorCheckpointID> checkpoint_id_to_restore_;

/// The RPC server.
rpc::NodeManagerServer node_manager_server_;
rpc::GrpcServer node_manager_server_;

/// The RPC service.
rpc::NodeManagerGrpcService node_manager_service_;

/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
rpc::ClientCallManager client_call_manager_;
Expand Down
17 changes: 12 additions & 5 deletions src/ray/rpc/grpc_server.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ray/rpc/grpc_server.h"
#include <grpcpp/impl/service_type.h>

namespace ray {
namespace rpc {
Expand All @@ -9,17 +10,18 @@ void GrpcServer::Run() {
grpc::ServerBuilder builder;
// TODO(hchen): Add options for authentication.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
// Allow subclasses to register concrete services.
RegisterServices(builder);
// Register all the services to this server.
for (auto &entry : services_) {
builder.RegisterService(&entry.get());
}
// Get hold of the completion queue used for the asynchronous communication
// with the gRPC runtime.
cq_ = builder.AddCompletionQueue();
// Build and start server.
server_ = builder.BuildAndStart();
RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << ".";

// Allow subclasses to initialize the server call factories.
InitServerCallFactories(&server_call_factories_and_concurrencies_);
// Create calls for all the server call factories.
for (auto &entry : server_call_factories_and_concurrencies_) {
for (int i = 0; i < entry.second; i++) {
// Create and request calls from the factory.
Expand All @@ -31,6 +33,11 @@ void GrpcServer::Run() {
polling_thread.detach();
}

void GrpcServer::RegisterService(GrpcService &service) {
services_.emplace_back(service.GetGrpcService());
service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_);
}

void GrpcServer::PollEventsFromCompletionQueue() {
void *tag;
bool ok;
Expand All @@ -48,7 +55,7 @@ void GrpcServer::PollEventsFromCompletionQueue() {
// incoming request.
server_call->GetFactory().CreateCall();
server_call->SetState(ServerCallState::PROCESSING);
main_service_.post([server_call] { server_call->HandleRequest(); });
server_call->HandleRequest();
break;
case ServerCallState::SENDING_REPLY:
// The reply has been sent, this call can be deleted now.
Expand Down
77 changes: 52 additions & 25 deletions src/ray/rpc/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
namespace ray {
namespace rpc {

/// Base class that represents an abstract gRPC server.
class GrpcService;

/// Class that represents an gRPC server.
///
/// A `GrpcServer` listens on a specific port. It owns
/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
Expand All @@ -28,11 +30,7 @@ class GrpcServer {
/// \param[in] name Name of this server, used for logging and debugging purpose.
/// \param[in] port The port to bind this server to. If it's 0, a random available port
/// will be chosen.
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcServer(const std::string &name, const uint32_t port,
boost::asio::io_service &main_service)
: name_(name), port_(port), main_service_(main_service) {}
GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {}

/// Destruct this gRPC server.
~GrpcServer() {
Expand All @@ -46,36 +44,25 @@ class GrpcServer {
/// Get the port of this gRPC server.
int GetPort() const { return port_; }

protected:
/// Subclasses should implement this method and register one or multiple gRPC services
/// to the given `ServerBuilder`.
/// Register a grpc service. Multiple services can be registered to the same server.
/// Note that the `service` registered must remain valid for the lifetime of the
/// `GrpcServer`, as it holds the underlying `grpc::Service`.
///
/// \param[in] builder The `ServerBuilder` instance to register services to.
virtual void RegisterServices(grpc::ServerBuilder &builder) = 0;

/// Subclasses should implement this method to initialize the `ServerCallFactory`
/// instances, as well as specify maximum number of concurrent requests that gRPC
/// server can "accept" (not "handle"). Each factory will be used to create
/// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
/// handle an incoming request.
///
/// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
/// and the maximum number of concurrent requests that gRPC server can accept.
virtual void InitServerCallFactories(
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) = 0;
/// \param[in] service A `GrpcService` to register to this server.
void RegisterService(GrpcService &service);

protected:
/// This function runs in a background thread. It keeps polling events from the
/// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances
/// via the `ServerCall` objects.
void PollEventsFromCompletionQueue();

/// The main event loop, to which the service handler functions will be posted.
boost::asio::io_service &main_service_;
/// Name of this server, used for logging and debugging purpose.
const std::string name_;
/// Port of this server.
int port_;
/// The `grpc::Service` objects which should be registered to `ServerBuilder`.
std::vector<std::reference_wrapper<grpc::Service>> services_;
/// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
/// gRPC server can accept.
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
Expand All @@ -86,6 +73,46 @@ class GrpcServer {
std::unique_ptr<grpc::Server> server_;
};

/// Base class that represents an abstract gRPC service.
///
/// Subclass should implement `InitServerCallFactories` to decide
/// which kinds of requests this service should accept.
class GrpcService {
public:
/// Constructor.
///
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {}

/// Destruct this gRPC service.
~GrpcService() {}

protected:
/// Return the underlying grpc::Service object for this class.
/// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`.
virtual grpc::Service &GetGrpcService() = 0;

/// Subclasses should implement this method to initialize the `ServerCallFactory`
/// instances, as well as specify maximum number of concurrent requests that gRPC
/// server can "accept" (not "handle"). Each factory will be used to create
/// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
/// handle an incoming request.
///
/// \param[in] cq The grpc completion queue.
/// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
/// and the maximum number of concurrent requests that gRPC server can accept.
virtual void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) = 0;

/// The main event loop, to which the service handler functions will be posted.
boost::asio::io_service &main_service_;

friend class GrpcServer;
};

} // namespace rpc
} // namespace ray

Expand Down
25 changes: 12 additions & 13 deletions src/ray/rpc/node_manager_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,31 @@ class NodeManagerServiceHandler {
RequestDoneCallback done_callback) = 0;
};

/// The `GrpcServer` for `NodeManagerService`.
class NodeManagerServer : public GrpcServer {
/// The `GrpcService` for `NodeManagerService`.
class NodeManagerGrpcService : public GrpcService {
public:
/// Constructor.
///
/// \param[in] port See super class.
/// \param[in] main_service See super class.
/// \param[in] io_service See super class.
/// \param[in] handler The service handler that actually handle the requests.
NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service,
NodeManagerServiceHandler &service_handler)
: GrpcServer("NodeManager", port, main_service),
service_handler_(service_handler){};
NodeManagerGrpcService(boost::asio::io_service &io_service,
NodeManagerServiceHandler &service_handler)
: GrpcService(io_service), service_handler_(service_handler){};

void RegisterServices(grpc::ServerBuilder &builder) override {
/// Register `NodeManagerService`.
builder.RegisterService(&service_);
}
protected:
grpc::Service &GetGrpcService() override { return service_; }

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the factory for `ForwardTask` requests.
std::unique_ptr<ServerCallFactory> forward_task_call_factory(
new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
ForwardTaskRequest, ForwardTaskReply>(
service_, &NodeManagerService::AsyncService::RequestForwardTask,
service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq,
main_service_));

// Set `ForwardTask`'s accept concurrency to 100.
server_call_factories_and_concurrencies->emplace_back(
Expand All @@ -61,6 +59,7 @@ class NodeManagerServer : public GrpcServer {
private:
/// The grpc async service object.
NodeManagerService::AsyncService service_;

/// The service handler that actually handle the requests.
NodeManagerServiceHandler &service_handler_;
};
Expand Down
26 changes: 21 additions & 5 deletions src/ray/rpc/server_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,27 @@ class ServerCallImpl : public ServerCall {
/// \param[in] factory The factory which created this call.
/// \param[in] service_handler The service handler that handles the request.
/// \param[in] handle_request_function Pointer to the service handler function.
/// \param[in] io_service The event loop.
ServerCallImpl(
const ServerCallFactory &factory, ServiceHandler &service_handler,
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function)
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
boost::asio::io_service &io_service)
: state_(ServerCallState::PENDING),
factory_(factory),
service_handler_(service_handler),
handle_request_function_(handle_request_function),
response_writer_(&context_) {}
response_writer_(&context_),
io_service_(io_service) {}

ServerCallState GetState() const override { return state_; }

void SetState(const ServerCallState &new_state) override { state_ = new_state; }

void HandleRequest() override {
io_service_.post([this] { HandleRequestImpl(); });
}

void HandleRequestImpl() {
state_ = ServerCallState::PROCESSING;
(service_handler_.*handle_request_function_)(request_, &reply_,
[this](Status status) {
Expand Down Expand Up @@ -146,6 +153,9 @@ class ServerCallImpl : public ServerCall {
/// The reponse writer.
grpc::ServerAsyncResponseWriter<Reply> response_writer_;

/// The event loop.
boost::asio::io_service &io_service_;

/// The request message.
Request request_;

Expand Down Expand Up @@ -185,23 +195,26 @@ class ServerCallFactoryImpl : public ServerCallFactory {
/// \param[in] service_handler The service handler that handles the request.
/// \param[in] handle_request_function Pointer to the service handler function.
/// \param[in] cq The `CompletionQueue`.
/// \param[in] io_service The event loop.
ServerCallFactoryImpl(
AsyncService &service,
RequestCallFunction<GrpcService, Request, Reply> request_call_function,
ServiceHandler &service_handler,
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
boost::asio::io_service &io_service)
: service_(service),
request_call_function_(request_call_function),
service_handler_(service_handler),
handle_request_function_(handle_request_function),
cq_(cq) {}
cq_(cq),
io_service_(io_service) {}

ServerCall *CreateCall() const override {
// Create a new `ServerCall`. This object will eventually be deleted by
// `GrpcServer::PollEventsFromCompletionQueue`.
auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
*this, service_handler_, handle_request_function_);
*this, service_handler_, handle_request_function_, io_service_);
/// Request gRPC runtime to starting accepting this kind of request, using the call as
/// the tag.
(service_.*request_call_function_)(&call->context_, &call->request_,
Expand All @@ -225,6 +238,9 @@ class ServerCallFactoryImpl : public ServerCallFactory {

/// The `CompletionQueue`.
const std::unique_ptr<grpc::ServerCompletionQueue> &cq_;

/// The event loop.
boost::asio::io_service &io_service_;
};

} // namespace rpc
Expand Down

0 comments on commit bb8e75b

Please sign in to comment.