Skip to content

Commit

Permalink
Fix a multithreading bug in grpc ClientCall (#5196)
Browse files Browse the repository at this point in the history
  • Loading branch information
raulchen authored Jul 15, 2019
1 parent 5b13a7e commit 7342117
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 50 deletions.
95 changes: 60 additions & 35 deletions src/ray/rpc/client_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ namespace rpc {

/// Represents an outgoing gRPC request.
///
/// The lifecycle of a `ClientCall` is as follows.
///
/// When a client submits a new gRPC request, a new `ClientCall` object will be created
/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of
/// `CompletionQueue`.
///
/// When the reply is received, `ClientCallMangager` will get the address of this object
/// via `CompletionQueue`'s tag. And the manager should call `OnReplyReceived` and then
/// delete this object.
///
/// NOTE(hchen): Compared to `ClientCallImpl`, this abstract interface doesn't use
/// template. This allows the users (e.g., `ClientCallMangager`) not having to use
/// template as well.
Expand All @@ -38,32 +28,33 @@ class ClientCall {

class ClientCallManager;

/// Reprents the client callback function of a particular rpc method.
/// Represents the client callback function of a particular rpc method.
///
/// \tparam Reply Type of the reply message.
template <class Reply>
using ClientCallback = std::function<void(const Status &status, const Reply &reply)>;

/// Implementaion of the `ClientCall`. It represents a `ClientCall` for a particular
/// Implementation of the `ClientCall`. It represents a `ClientCall` for a particular
/// RPC method.
///
/// \tparam Reply Type of the Reply message.
template <class Reply>
class ClientCallImpl : public ClientCall {
public:
/// Constructor.
///
/// \param[in] callback The callback function to handle the reply.
explicit ClientCallImpl(const ClientCallback<Reply> &callback) : callback_(callback) {}

Status GetStatus() override { return GrpcStatusToRayStatus(status_); }

void OnReplyReceived() override {
if (callback_ != nullptr) {
callback_(GrpcStatusToRayStatus(status_), reply_);
}
}

private:
/// Constructor.
///
/// \param[in] callback The callback function to handle the reply.
ClientCallImpl(const ClientCallback<Reply> &callback) : callback_(callback) {}

/// The reply message.
Reply reply_;

Expand All @@ -83,7 +74,32 @@ class ClientCallImpl : public ClientCall {
friend class ClientCallManager;
};

/// Peprents the generic signature of a `FooService::Stub::PrepareAsyncBar`
/// This class wraps a `ClientCall`, and is used as the `tag` of gRPC's `CompletionQueue`.
///
/// The lifecycle of a `ClientCallTag` is as follows.
///
/// When a client submits a new gRPC request, a new `ClientCallTag` object will be created
/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of
/// `CompletionQueue`.
///
/// When the reply is received, `ClientCallMangager` will get the address of this object
/// via `CompletionQueue`'s tag. And the manager should call
/// `GetCall()->OnReplyReceived()` and then delete this object.
class ClientCallTag {
public:
/// Constructor.
///
/// \param call A `ClientCall` that represents a request.
explicit ClientCallTag(std::shared_ptr<ClientCall> call) : call_(std::move(call)) {}

/// Get the wrapped `ClientCall`.
const std::shared_ptr<ClientCall> &GetCall() const { return call_; }

private:
std::shared_ptr<ClientCall> call_;
};

/// Represents the generic signature of a `FooService::Stub::PrepareAsyncBar`
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
///
/// \tparam GrpcService Type of the gRPC-generated service class.
Expand All @@ -100,14 +116,15 @@ using PrepareAsyncFunction = std::unique_ptr<grpc::ClientAsyncResponseReader<Rep
/// It maintains a thread that keeps polling events from `CompletionQueue`, and post
/// the callback function to the main event loop when a reply is received.
///
/// Mutiple clients can share one `ClientCallManager`.
/// Multiple clients can share one `ClientCallManager`.
class ClientCallManager {
public:
/// Constructor.
///
/// \param[in] main_service The main event loop, to which the callback functions will be
/// posted.
ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) {
explicit ClientCallManager(boost::asio::io_service &main_service)
: main_service_(main_service) {
// Start the polling thread.
std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
polling_thread.detach();
Expand All @@ -117,50 +134,58 @@ class ClientCallManager {

/// Create a new `ClientCall` and send request.
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
///
/// \param[in] stub The gRPC-generated stub.
/// \param[in] prepare_async_function Pointer to the gRPC-generated
/// `FooService::Stub::PrepareAsyncBar` function.
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
/// \return A `ClientCall` representing the request that was just sent.
template <class GrpcService, class Request, class Reply>
ClientCall *CreateCall(
std::shared_ptr<ClientCall> CreateCall(
typename GrpcService::Stub &stub,
const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
const Request &request, const ClientCallback<Reply> &callback) {
// Create a new `ClientCall` object. This object will eventuall be deleted in the
// `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
auto call = new ClientCallImpl<Reply>(callback);
auto call = std::make_shared<ClientCallImpl<Reply>>(callback);
// Send request.
call->response_reader_ =
(stub.*prepare_async_function)(&call->context_, request, &cq_);
call->response_reader_->StartCall();
call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call);
// Create a new tag object. This object will eventually be deleted in the
// `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
//
// NOTE(chen): Unlike `ServerCall`, we can't directly use `ClientCall` as the tag.
// Because this function must return a `shared_ptr` to make sure the returned
// `ClientCall` is safe to use. But `response_reader_->Finish` only accepts a raw
// pointer.
auto tag = new ClientCallTag(call);
call->response_reader_->Finish(&call->reply_, &call->status_, (void *)tag);
return call;
}

private:
/// This function runs in a background thread. It keeps polling events from the
/// `CompletionQueue`, and dispaches the event to the callbacks via the `ClientCall`
/// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
/// objects.
void PollEventsFromCompletionQueue() {
void *got_tag;
bool ok = false;
// Keep reading events from the `CompletionQueue` until it's shutdown.
while (cq_.Next(&got_tag, &ok)) {
auto *call = reinterpret_cast<ClientCall *>(got_tag);
auto tag = reinterpret_cast<ClientCallTag *>(got_tag);
if (ok) {
// Post the callback to the main event loop.
main_service_.post([call]() {
call->OnReplyReceived();
// The call is finished, we can delete the `ClientCall` object now.
delete call;
main_service_.post([tag]() {
tag->GetCall()->OnReplyReceived();
// The call is finished, and we can delete this tag now.
delete tag;
});
} else {
delete call;
delete tag;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ray/rpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void GrpcServer::Run() {
// TODO(hchen): Add options for authentication.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
// Register all the services to this server.
if (services_.size() == 0) {
if (services_.empty()) {
RAY_LOG(WARNING) << "No service is found when start grpc server " << name_;
}
for (auto &entry : services_) {
Expand Down
10 changes: 6 additions & 4 deletions src/ray/rpc/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define RAY_RPC_GRPC_SERVER_H

#include <thread>
#include <utility>

#include <grpcpp/grpcpp.h>
#include <boost/asio.hpp>
Expand Down Expand Up @@ -32,8 +33,8 @@ class GrpcServer {
/// will be chosen.
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcServer(const std::string &name, const uint32_t port)
: name_(name), port_(port), is_closed_(true) {}
GrpcServer(std::string name, const uint32_t port)
: name_(std::move(name)), port_(port), is_closed_(true) {}

/// Destruct this gRPC server.
~GrpcServer() { Shutdown(); }
Expand Down Expand Up @@ -98,10 +99,11 @@ class GrpcService {
///
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {}
explicit GrpcService(boost::asio::io_service &main_service)
: main_service_(main_service) {}

/// Destruct this gRPC service.
~GrpcService() {}
~GrpcService() = default;

protected:
/// Return the underlying grpc::Service object for this class.
Expand Down
19 changes: 9 additions & 10 deletions src/ray/rpc/server_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ enum class ServerCallState {

class ServerCallFactory;

/// Reprensents an incoming request of a gRPC server.
/// Represents an incoming request of a gRPC server.
///
/// The lifecycle and state transition of a `ServerCall` is as follows:
///
Expand Down Expand Up @@ -79,10 +79,10 @@ class ServerCall {
class ServerCallFactory {
public:
/// Create a new `ServerCall` and request gRPC runtime to start accepting the
/// corresonding type of requests.
/// corresponding type of requests.
///
/// \return Pointer to the `ServerCall` object.
virtual ServerCall *CreateCall() const = 0;
virtual void CreateCall() const = 0;

virtual ~ServerCallFactory() = default;
};
Expand Down Expand Up @@ -145,7 +145,7 @@ class ServerCallImpl : public ServerCall {
[this](Status status, std::function<void()> success,
std::function<void()> failure) {
// These two callbacks must be set before `SendReply`, because `SendReply`
// is aysnc and this `ServerCall` might be deleted right after `SendReply`.
// is async and this `ServerCall` might be deleted right after `SendReply`.
send_reply_success_callback_ = std::move(success);
send_reply_failure_callback_ = std::move(failure);

Expand All @@ -158,14 +158,14 @@ class ServerCallImpl : public ServerCall {

const ServerCallFactory &GetFactory() const override { return factory_; }

void OnReplySent() {
void OnReplySent() override {
if (send_reply_success_callback_ && !io_service_.stopped()) {
auto callback = std::move(send_reply_success_callback_);
io_service_.post([callback]() { callback(); });
}
}

void OnReplyFailed() {
void OnReplyFailed() override {
if (send_reply_failure_callback_ && !io_service_.stopped()) {
auto callback = std::move(send_reply_failure_callback_);
io_service_.post([callback]() { callback(); });
Expand All @@ -174,7 +174,7 @@ class ServerCallImpl : public ServerCall {

private:
/// Tell gRPC to finish this request and send reply asynchronously.
void SendReply(Status status) {
void SendReply(const Status &status) {
state_ = ServerCallState::SENDING_REPLY;
response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
}
Expand All @@ -195,7 +195,7 @@ class ServerCallImpl : public ServerCall {
/// of compression, authentication, as well as to send metadata back to the client.
grpc::ServerContext context_;

/// The reponse writer.
/// The response writer.
grpc::ServerAsyncResponseWriter<Reply> response_writer_;

/// The event loop.
Expand Down Expand Up @@ -261,7 +261,7 @@ class ServerCallFactoryImpl : public ServerCallFactory {
cq_(cq),
io_service_(io_service) {}

ServerCall *CreateCall() const override {
void CreateCall() const override {
// Create a new `ServerCall`. This object will eventually be deleted by
// `GrpcServer::PollEventsFromCompletionQueue`.
auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
Expand All @@ -271,7 +271,6 @@ class ServerCallFactoryImpl : public ServerCallFactory {
(service_.*request_call_function_)(&call->context_, &call->request_,
&call->response_writer_, cq_.get(), cq_.get(),
call);
return call;
}

private:
Expand Down

0 comments on commit 7342117

Please sign in to comment.