Skip to content

Commit

Permalink
Support shutdown write and notify for success write
Browse files Browse the repository at this point in the history
  • Loading branch information
chenBright committed Mar 16, 2024
1 parent 24fc31e commit f436e4d
Show file tree
Hide file tree
Showing 5 changed files with 402 additions and 36 deletions.
1 change: 1 addition & 0 deletions src/brpc/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ BAIDU_REGISTER_ERRNO(brpc::ELOGOFF, "Server is stopping");
BAIDU_REGISTER_ERRNO(brpc::ELIMIT, "Reached server's max_concurrency");
BAIDU_REGISTER_ERRNO(brpc::ECLOSE, "Close socket initiatively");
BAIDU_REGISTER_ERRNO(brpc::EITP, "Bad Itp response");
BAIDU_REGISTER_ERRNO(brpc::ESHUTDOWNWRITE, "Shutdown write of socket");

#if BRPC_WITH_RDMA
BAIDU_REGISTER_ERRNO(brpc::ERDMA, "RDMA verbs error");
Expand Down
1 change: 1 addition & 0 deletions src/brpc/errno.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ enum Errno {
ELIMIT = 2004; // Reached server's limit on resources
ECLOSE = 2005; // Close socket initiatively
EITP = 2006; // Failed Itp response
ESHUTDOWNWRITE = 2007; // Shutdown write of socket

// Errno related to RDMA (may happen at both sides)
ERDMA = 3001; // RDMA verbs error
Expand Down
130 changes: 101 additions & 29 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,33 +307,56 @@ const uint32_t MAX_PIPELINED_COUNT = 16384;

struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
static WriteRequest* const UNCONNECTED;

butil::IOBuf data;
WriteRequest* next;
bthread_id_t id_wait;
Socket* socket;

void clear_and_set_control_bits(bool notify_on_success,
bool shutdown_write) {
_socket_and_control_bits.set_extra(
(uint16_t)notify_on_success << 1 | (uint16_t)shutdown_write);
}

void set_socket(Socket* s) {
_socket_and_control_bits.set(s);
}

// If this field is set to true, notify when write successfully.
bool is_notify_on_success() const {
return _socket_and_control_bits.extra() & ((uint16_t)1 << 1);
}

// Whether shutdown write of the socket after this write complete.
bool need_shutdown_write() const {
return _socket_and_control_bits.extra() & (uint16_t)1;
}

Socket* get_socket() const {
return _socket_and_control_bits.get();
}

uint32_t pipelined_count() const {
return (_pc_and_udmsg >> 48) & 0x3FFF;
return _pc_and_udmsg.extra() & 0x3FFF;
}
uint32_t get_auth_flags() const {
return (_pc_and_udmsg >> 62) & 0x03;
return (_pc_and_udmsg.extra() >> 14) & 0x03;
}
void clear_pipelined_count_and_auth_flags() {
_pc_and_udmsg &= 0xFFFFFFFFFFFFULL;
_pc_and_udmsg.reset_ptr_and_extra();
}
SocketMessage* user_message() const {
return (SocketMessage*)(_pc_and_udmsg & 0xFFFFFFFFFFFFULL);
return _pc_and_udmsg.get();
}
void clear_user_message() {
_pc_and_udmsg &= 0xFFFF000000000000ULL;
_pc_and_udmsg.reset();
}
void set_pipelined_count_and_user_message(
uint32_t pc, SocketMessage* msg, uint32_t auth_flags) {
if (auth_flags) {
pc |= (auth_flags & 0x03) << 14;
}
_pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg;
_pc_and_udmsg.set_ptr_and_extra(msg, pc);
}

bool reset_pipelined_count_and_user_message() {
Expand All @@ -355,7 +378,10 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
void Setup(Socket* s);

private:
uint64_t _pc_and_udmsg;
// Socket pointer and some control bits.
PackedPtr<Socket> _socket_and_control_bits;
// User message pointer, pipelined count auth flag.
PackedPtr<SocketMessage> _pc_and_udmsg;
};

void Socket::WriteRequest::Setup(Socket* s) {
Expand Down Expand Up @@ -399,7 +425,7 @@ class Socket::EpollOutRequest : public SocketUser {
EpollOutRequest() : fd(-1), timer_id(0)
, on_epollout_event(NULL), data(NULL) {}

~EpollOutRequest() {
~EpollOutRequest() override {
// Remove the timer at last inside destructor to avoid
// race with the place that registers the timer
if (timer_id) {
Expand All @@ -408,8 +434,8 @@ class Socket::EpollOutRequest : public SocketUser {
}
}

void BeforeRecycle(Socket*) {
// Recycle itself
void BeforeRecycle(Socket*) override {
// Recycle itself.
delete this;
}

Expand Down Expand Up @@ -464,6 +490,7 @@ Socket::Socket(Forbidden)
, _unwritten_bytes(0)
, _epollout_butex(NULL)
, _write_head(NULL)
, _is_wirte_shutdown(false)
, _stream_set(NULL)
, _total_streams_unconsumed_size(0)
, _ninflight_app_health_check(0)
Expand All @@ -485,7 +512,11 @@ void Socket::ReturnSuccessfulWriteRequest(Socket::WriteRequest* p) {
const bthread_id_t id_wait = p->id_wait;
butil::return_object(p);
if (id_wait != INVALID_BTHREAD_ID) {
NotifyOnFailed(id_wait);
if (p->is_notify_on_success() && !Failed()) {
bthread_id_error(id_wait, 0);
} else {
NotifyOnFailed(id_wait);
}
}
}

Expand Down Expand Up @@ -514,11 +545,18 @@ Socket::WriteRequest* Socket::ReleaseWriteRequestsExceptLast(
}

void Socket::ReleaseAllFailedWriteRequests(Socket::WriteRequest* req) {
CHECK(Failed());
pthread_mutex_lock(&_id_wait_list_mutex);
const int error_code = non_zero_error_code();
const std::string error_text = _error_text;
pthread_mutex_unlock(&_id_wait_list_mutex);
CHECK(Failed() || IsWriteShutdown());
int error_code;
std::string error_text;
if (Failed()) {
pthread_mutex_lock(&_id_wait_list_mutex);
error_code = non_zero_error_code();
error_text = _error_text;
pthread_mutex_unlock(&_id_wait_list_mutex);
} else {
error_code = ESHUTDOWNWRITE;
error_text = "Shutdown write of the socket";
}
// Notice that `req' is not tail if Address after IsWriteComplete fails.
do {
req = ReleaseWriteRequestsExceptLast(req, error_code, error_text);
Expand Down Expand Up @@ -746,6 +784,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
m->_keepalive_options = options.keepalive_options;
m->_bthread_tag = options.bthread_tag;
CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed));
m->_is_wirte_shutdown = false;
// Must be last one! Internal fields of this Socket may be access
// just after calling ResetFileDescriptor.
if (m->ResetFileDescriptor(options.fd) != 0) {
Expand Down Expand Up @@ -1385,7 +1424,7 @@ int Socket::ConnectIfNot(const timespec* abstime, WriteRequest* req) {
// Have to hold a reference for `req'
SocketUniquePtr s;
ReAddress(&s);
req->socket = s.get();
req->set_socket(s.get());
if (_conn) {
if (_conn->Connect(this, abstime, KeepWriteIfConnected, req) < 0) {
return -1;
Expand Down Expand Up @@ -1457,7 +1496,7 @@ int Socket::HandleEpollOutRequest(int error_code, EpollOutRequest* req) {
void Socket::AfterAppConnected(int err, void* data) {
WriteRequest* req = static_cast<WriteRequest*>(data);
if (err == 0) {
Socket* const s = req->socket;
Socket* const s = req->get_socket();
SharedPart* sp = s->GetSharedPart();
if (sp) {
sp->num_continuous_connect_timeouts.store(0, butil::memory_order_relaxed);
Expand All @@ -1471,7 +1510,7 @@ void Socket::AfterAppConnected(int err, void* data) {
KeepWrite(req);
}
} else {
SocketUniquePtr s(req->socket);
SocketUniquePtr s(req->get_socket());
if (err == ETIMEDOUT) {
SharedPart* sp = s->GetOrNewSharedPart();
if (sp->num_continuous_connect_timeouts.fetch_add(
Expand Down Expand Up @@ -1499,7 +1538,7 @@ static void* RunClosure(void* arg) {

int Socket::KeepWriteIfConnected(int fd, int err, void* data) {
WriteRequest* req = static_cast<WriteRequest*>(data);
Socket* s = req->socket;
Socket* s = req->get_socket();
if (err == 0 && s->ssl_state() == SSL_CONNECTING) {
// Run ssl connect in a new bthread to avoid blocking
// the current bthread (thus blocking the EventDispatcher)
Expand All @@ -1522,12 +1561,13 @@ int Socket::KeepWriteIfConnected(int fd, int err, void* data) {
void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) {
butil::fd_guard sockfd(fd);
WriteRequest* req = static_cast<WriteRequest*>(data);
Socket* s = req->socket;
Socket* s = req->get_socket();
CHECK_GE(sockfd, 0);
if (err == 0 && s->CheckConnected(sockfd) == 0
&& s->ResetFileDescriptor(sockfd) == 0) {
if (s->_app_connect) {
s->_app_connect->StartConnect(req->socket, AfterAppConnected, req);
s->_app_connect->StartConnect(req->get_socket(),
AfterAppConnected, req);
} else {
// Successfully created a connection
AfterAppConnected(0, req);
Expand Down Expand Up @@ -1614,6 +1654,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) {
// wait until it points to a valid WriteRequest or NULL.
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write);
req->set_pipelined_count_and_user_message(
opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags);
return StartWrite(req, opt);
Expand Down Expand Up @@ -1650,7 +1691,9 @@ int Socket::Write(SocketMessagePtr<>& msg, const WriteOptions* options_in) {
// wait until it points to a valid WriteRequest or NULL.
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.auth_flags);
req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write);
req->set_pipelined_count_and_user_message(
opt.pipelined_count, msg.release(), opt.auth_flags);
return StartWrite(req, opt);
}

Expand All @@ -1672,12 +1715,19 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) {
bthread_t th;
SocketUniquePtr ptr_for_keep_write;
ssize_t nw = 0;
int ret = 0;

// We've got the right to write.
req->next = NULL;

// Fast fail when write has been shutdown.
if (_is_wirte_shutdown) {
goto FAIL_TO_WRITE;
}
_is_wirte_shutdown = req->need_shutdown_write();

// Connect to remote_side() if not.
int ret = ConnectIfNot(opt.abstime, req);
ret = ConnectIfNot(opt.abstime, req);
if (ret < 0) {
saved_errno = errno;
SetFailed(errno, "Fail to connect %s directly: %m", description().c_str());
Expand Down Expand Up @@ -1736,7 +1786,7 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) {

KEEPWRITE_IN_BACKGROUND:
ReAddress(&ptr_for_keep_write);
req->socket = ptr_for_keep_write.release();
req->set_socket(ptr_for_keep_write.release());
if (bthread_start_background(&th, &BTHREAD_ATTR_NORMAL,
KeepWrite, req) != 0) {
LOG(FATAL) << "Fail to start KeepWrite";
Expand All @@ -1758,19 +1808,26 @@ static const size_t DATA_LIST_MAX = 256;
void* Socket::KeepWrite(void* void_arg) {
g_vars->nkeepwrite << 1;
WriteRequest* req = static_cast<WriteRequest*>(void_arg);
SocketUniquePtr s(req->socket);
SocketUniquePtr s(req->get_socket());

// When error occurs, spin until there's no more requests instead of
// returning directly otherwise _write_head is permantly non-NULL which
// makes later Write() abnormal.
WriteRequest* cur_tail = NULL;
do {
// req was written, skip it.
bool need_shutdown = false;
if (req->next != NULL && req->data.empty()) {
WriteRequest* const saved_req = req;
need_shutdown = req->need_shutdown_write();
req = req->next;
s->ReturnSuccessfulWriteRequest(saved_req);
}
if (need_shutdown) {
LOG(WARNING) << "Shutdown write of " << *s;
break;
}

const ssize_t nw = s->DoWrite(req);
if (nw < 0) {
if (errno != EAGAIN && errno != EOVERCROWDED) {
Expand All @@ -1783,11 +1840,19 @@ void* Socket::KeepWrite(void* void_arg) {
} else {
s->AddOutputBytes(nw);
}
// Release WriteRequest until non-empty data or last request.
// Release WriteRequest until non-empty data, last request or shutdown write.
while (req->next != NULL && req->data.empty()) {
WriteRequest* const saved_req = req;
need_shutdown = req->need_shutdown_write();
req = req->next;
s->ReturnSuccessfulWriteRequest(saved_req);
if (need_shutdown) {
break;
}
}
if (need_shutdown) {
LOG(WARNING) << "Shutdown write of " << *s;
break;
}
// TODO(gejun): wait for epollout when we actually have written
// all the data. This weird heuristic reduces 30us delay...
Expand Down Expand Up @@ -1867,6 +1932,11 @@ ssize_t Socket::DoWrite(WriteRequest* req) {
for (WriteRequest* p = req; p != NULL && ndata < DATA_LIST_MAX;
p = p->next) {
data_list[ndata++] = &p->data;
if (p->need_shutdown_write()) {
// Write WriteRequest until shutdown write.
_is_wirte_shutdown = true;
break;
}
}

if (ssl_state() == SSL_OFF) {
Expand Down Expand Up @@ -2387,6 +2457,8 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) {
os << "\n}";
}

os << "\nis_wirte_shutdown=" << ptr->_is_wirte_shutdown;

{
int keepalive = 0;
socklen_t len = sizeof(keepalive);
Expand Down
Loading

0 comments on commit f436e4d

Please sign in to comment.