diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 808eeb6fd2110..226a8fb6d2516 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -101,7 +101,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), actor_registry_(), - node_manager_server_(config.node_manager_port, io_service, *this), + node_manager_server_("NodeManager", config.node_manager_port), + node_manager_service_(io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -119,6 +120,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manger rpc server. + node_manager_server_.RegisterService(node_manager_service_); node_manager_server_.Run(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index f45c8b0355536..7e812183657cf 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -512,7 +512,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler { std::unordered_map checkpoint_id_to_restore_; /// The RPC server. - rpc::NodeManagerServer node_manager_server_; + rpc::GrpcServer node_manager_server_; + + /// The RPC service. + rpc::NodeManagerGrpcService node_manager_service_; /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. 
rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index feb788da76923..f507039990c28 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -1,4 +1,5 @@ #include "ray/rpc/grpc_server.h" +#include <grpcpp/impl/service_type.h> namespace ray { namespace rpc { @@ -9,8 +10,10 @@ void GrpcServer::Run() { grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); - // Allow subclasses to register concrete services. - RegisterServices(builder); + // Register all the services to this server. + for (auto &entry : services_) { + builder.RegisterService(&entry.get()); + } // Get hold of the completion queue used for the asynchronous communication // with the gRPC runtime. cq_ = builder.AddCompletionQueue(); @@ -18,8 +21,7 @@ void GrpcServer::Run() { server_ = builder.BuildAndStart(); RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; - // Allow subclasses to initialize the server call factories. - InitServerCallFactories(&server_call_factories_and_concurrencies_); + // Create calls for all the server call factories. for (auto &entry : server_call_factories_and_concurrencies_) { for (int i = 0; i < entry.second; i++) { // Create and request calls from the factory. @@ -31,6 +33,11 @@ void GrpcServer::Run() { polling_thread.detach(); } +void GrpcServer::RegisterService(GrpcService &service) { + services_.emplace_back(service.GetGrpcService()); + service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_); +} + void GrpcServer::PollEventsFromCompletionQueue() { void *tag; bool ok; @@ -48,7 +55,7 @@ void GrpcServer::PollEventsFromCompletionQueue() { // incoming request. 
server_call->GetFactory().CreateCall(); server_call->SetState(ServerCallState::PROCESSING); - main_service_.post([server_call] { server_call->HandleRequest(); }); + server_call->HandleRequest(); break; case ServerCallState::SENDING_REPLY: // The reply has been sent, this call can be deleted now. diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 4953f470610fc..584da6565a47a 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -12,7 +12,9 @@ namespace ray { namespace rpc { -/// Base class that represents an abstract gRPC server. +class GrpcService; + +/// Class that represents a gRPC server. /// /// A `GrpcServer` listens on a specific port. It owns /// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, @@ -28,11 +30,7 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - /// \param[in] main_service The main event loop, to which service handler functions - /// will be posted. - GrpcServer(const std::string &name, const uint32_t port, - boost::asio::io_service &main_service) - : name_(name), port_(port), main_service_(main_service) {} + GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {} /// Destruct this gRPC server. ~GrpcServer() { @@ -46,36 +44,25 @@ class GrpcServer { /// Get the port of this gRPC server. int GetPort() const { return port_; } - protected: - /// Subclasses should implement this method and register one or multiple gRPC services - /// to the given `ServerBuilder`. + /// Register a grpc service. Multiple services can be registered to the same server. + /// Note that the `service` registered must remain valid for the lifetime of the + /// `GrpcServer`, as it holds the underlying `grpc::Service`. /// - /// \param[in] builder The `ServerBuilder` instance to register services to. 
- virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; - - /// Subclasses should implement this method to initialize the `ServerCallFactory` - /// instances, as well as specify maximum number of concurrent requests that gRPC - /// server can "accept" (not "handle"). Each factory will be used to create - /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and - /// handle an incoming request. - /// - /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, - /// and the maximum number of concurrent requests that gRPC server can accept. - virtual void InitServerCallFactories( - std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>> - *server_call_factories_and_concurrencies) = 0; + /// \param[in] service A `GrpcService` to register to this server. + void RegisterService(GrpcService &service); + protected: /// This function runs in a background thread. It keeps polling events from the /// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances /// via the `ServerCall` objects. void PollEventsFromCompletionQueue(); - /// The main event loop, to which the service handler functions will be posted. - boost::asio::io_service &main_service_; /// Name of this server, used for logging and debugging purpose. const std::string name_; /// Port of this server. int port_; + /// The `grpc::Service` objects which should be registered to `ServerBuilder`. + std::vector<std::reference_wrapper<grpc::Service>> services_; /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that /// gRPC server can accept. std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>> @@ -86,6 +73,46 @@ class GrpcServer { std::unique_ptr<grpc::Server> server_; }; +/// Base class that represents an abstract gRPC service. +/// +/// Subclasses should implement `InitServerCallFactories` to decide +/// which kinds of requests this service should accept. +class GrpcService { + public: + /// Constructor. 
+ /// + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted. + GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {} + + /// Destruct this gRPC service. + ~GrpcService() {} + + protected: + /// Return the underlying grpc::Service object for this class. + /// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`. + virtual grpc::Service &GetGrpcService() = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify maximum number of concurrent requests that gRPC + /// server can "accept" (not "handle"). Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[in] cq The grpc completion queue. + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that gRPC server can accept. + virtual void InitServerCallFactories( + const std::unique_ptr<grpc::ServerCompletionQueue> &cq, + std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>> + *server_call_factories_and_concurrencies) = 0; + + /// The main event loop, to which the service handler functions will be posted. + boost::asio::io_service &main_service_; + + friend class GrpcServer; +}; + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h index afaea299ea891..d05f268c65b24 100644 --- a/src/ray/rpc/node_manager_server.h +++ b/src/ray/rpc/node_manager_server.h @@ -25,25 +25,22 @@ class NodeManagerServiceHandler { RequestDoneCallback done_callback) = 0; }; -/// The `GrpcServer` for `NodeManagerService`. -class NodeManagerServer : public GrpcServer { +/// The `GrpcService` for `NodeManagerService`. +class NodeManagerGrpcService : public GrpcService { public: /// Constructor. /// - /// \param[in] port See super class. 
- /// \param[in] main_service See super class. + /// \param[in] io_service See super class. /// \param[in] handler The service handler that actually handle the requests. - NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, - NodeManagerServiceHandler &service_handler) - : GrpcServer("NodeManager", port, main_service), - service_handler_(service_handler){}; + NodeManagerGrpcService(boost::asio::io_service &io_service, + NodeManagerServiceHandler &service_handler) + : GrpcService(io_service), service_handler_(service_handler){}; - void RegisterServices(grpc::ServerBuilder &builder) override { - /// Register `NodeManagerService`. - builder.RegisterService(&service_); - } + protected: + grpc::Service &GetGrpcService() override { return service_; } void InitServerCallFactories( + const std::unique_ptr<grpc::ServerCompletionQueue> &cq, std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>> *server_call_factories_and_concurrencies) override { // Initialize the factory for `ForwardTask` requests. @@ -51,7 +48,8 @@ class NodeManagerServer : public GrpcServer { new ServerCallFactoryImpl( service_, &NodeManagerService::AsyncService::RequestForwardTask, - service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_)); + service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq, + main_service_)); // Set `ForwardTask`'s accept concurrency to 100. server_call_factories_and_concurrencies->emplace_back( @@ -61,6 +59,7 @@ class NodeManagerServer : public GrpcServer { private: /// The grpc async service object. NodeManagerService::AsyncService service_; + /// The service handler that actually handle the requests. NodeManagerServiceHandler &service_handler_; }; diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index e06278260ab67..08ca128323ee3 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -94,20 +94,27 @@ class ServerCallImpl : public ServerCall { /// \param[in] factory The factory which created this call. 
/// \param[in] service_handler The service handler that handles the request. /// \param[in] handle_request_function Pointer to the service handler function. + /// \param[in] io_service The event loop. ServerCallImpl( const ServerCallFactory &factory, ServiceHandler &service_handler, - HandleRequestFunction handle_request_function) + HandleRequestFunction handle_request_function, + boost::asio::io_service &io_service) : state_(ServerCallState::PENDING), factory_(factory), service_handler_(service_handler), handle_request_function_(handle_request_function), - response_writer_(&context_) {} + response_writer_(&context_), + io_service_(io_service) {} ServerCallState GetState() const override { return state_; } void SetState(const ServerCallState &new_state) override { state_ = new_state; } void HandleRequest() override { + io_service_.post([this] { HandleRequestImpl(); }); + } + + void HandleRequestImpl() { state_ = ServerCallState::PROCESSING; (service_handler_.*handle_request_function_)(request_, &reply_, [this](Status status) { @@ -146,6 +153,9 @@ class ServerCallImpl : public ServerCall { /// The reponse writer. grpc::ServerAsyncResponseWriter response_writer_; + /// The event loop. + boost::asio::io_service &io_service_; + /// The request message. Request request_; @@ -185,23 +195,26 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] service_handler The service handler that handles the request. /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] cq The `CompletionQueue`. + /// \param[in] io_service The event loop. 
ServerCallFactoryImpl( AsyncService &service, RequestCallFunction request_call_function, ServiceHandler &service_handler, HandleRequestFunction handle_request_function, - const std::unique_ptr<grpc::ServerCompletionQueue> &cq) + const std::unique_ptr<grpc::ServerCompletionQueue> &cq, + boost::asio::io_service &io_service) : service_(service), request_call_function_(request_call_function), service_handler_(service_handler), handle_request_function_(handle_request_function), - cq_(cq) {} + cq_(cq), + io_service_(io_service) {} ServerCall *CreateCall() const override { // Create a new `ServerCall`. This object will eventually be deleted by // `GrpcServer::PollEventsFromCompletionQueue`. auto call = new ServerCallImpl( - *this, service_handler_, handle_request_function_); + *this, service_handler_, handle_request_function_, io_service_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. (service_.*request_call_function_)(&call->context_, &call->request_, @@ -225,6 +238,9 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// The `CompletionQueue`. const std::unique_ptr<grpc::ServerCompletionQueue> &cq_; + + /// The event loop. + boost::asio::io_service &io_service_; }; } // namespace rpc