diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp index bfe278ffb9..80adbd37e0 100644 --- a/src/brpc/controller.cpp +++ b/src/brpc/controller.cpp @@ -80,6 +80,7 @@ BAIDU_REGISTER_ERRNO(brpc::ELOGOFF, "Server is stopping"); BAIDU_REGISTER_ERRNO(brpc::ELIMIT, "Reached server's max_concurrency"); BAIDU_REGISTER_ERRNO(brpc::ECLOSE, "Close socket initiatively"); BAIDU_REGISTER_ERRNO(brpc::EITP, "Bad Itp response"); +BAIDU_REGISTER_ERRNO(brpc::ESHUTDOWNWRITE, "Shutdown write of socket"); #if BRPC_WITH_RDMA BAIDU_REGISTER_ERRNO(brpc::ERDMA, "RDMA verbs error"); diff --git a/src/brpc/errno.proto b/src/brpc/errno.proto index fccd8edb8d..26ffadc201 100644 --- a/src/brpc/errno.proto +++ b/src/brpc/errno.proto @@ -49,6 +49,7 @@ enum Errno { ELIMIT = 2004; // Reached server's limit on resources ECLOSE = 2005; // Close socket initiatively EITP = 2006; // Failed Itp response + ESHUTDOWNWRITE = 2007; // Shutdown write of socket // Errno related to RDMA (may happen at both sides) ERDMA = 3001; // RDMA verbs error diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 1f9b2a2e9b..554640d531 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -307,33 +307,56 @@ const uint32_t MAX_PIPELINED_COUNT = 16384; struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest { static WriteRequest* const UNCONNECTED; - + butil::IOBuf data; WriteRequest* next; bthread_id_t id_wait; - Socket* socket; + + void clear_and_set_control_bits(bool notify_on_success, + bool shutdown_write) { + _socket_and_control_bits.set_extra( + (uint16_t)notify_on_success << 1 | (uint16_t)shutdown_write); + } + + void set_socket(Socket* s) { + _socket_and_control_bits.set(s); + } + + // If this field is set to true, notify when write successfully. + bool is_notify_on_success() const { + return _socket_and_control_bits.extra() & ((uint16_t)1 << 1); + } + + // Whether shutdown write of the socket after this write complete. + bool need_shutdown_write() const { + return _socket_and_control_bits.extra() & (uint16_t)1; + } + + Socket* get_socket() const { + return _socket_and_control_bits.get(); + } uint32_t pipelined_count() const { - return (_pc_and_udmsg >> 48) & 0x3FFF; + return _pc_and_udmsg.extra() & 0x3FFF; } uint32_t get_auth_flags() const { - return (_pc_and_udmsg >> 62) & 0x03; + return (_pc_and_udmsg.extra() >> 14) & 0x03; } void clear_pipelined_count_and_auth_flags() { - _pc_and_udmsg &= 0xFFFFFFFFFFFFULL; + _pc_and_udmsg.reset_ptr_and_extra(); } SocketMessage* user_message() const { - return (SocketMessage*)(_pc_and_udmsg & 0xFFFFFFFFFFFFULL); + return _pc_and_udmsg.get(); } void clear_user_message() { - _pc_and_udmsg &= 0xFFFF000000000000ULL; + _pc_and_udmsg.reset(); } void set_pipelined_count_and_user_message( uint32_t pc, SocketMessage* msg, uint32_t auth_flags) { if (auth_flags) { pc |= (auth_flags & 0x03) << 14; } - _pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg; + _pc_and_udmsg.set_ptr_and_extra(msg, pc); } bool reset_pipelined_count_and_user_message() { @@ -355,7 +378,10 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest { void Setup(Socket* s); private: - uint64_t _pc_and_udmsg; + // Socket pointer and some control bits. + PackedPtr _socket_and_control_bits; + // User message pointer, pipelined count auth flag. + PackedPtr _pc_and_udmsg; }; void Socket::WriteRequest::Setup(Socket* s) { @@ -399,7 +425,7 @@ class Socket::EpollOutRequest : public SocketUser { EpollOutRequest() : fd(-1), timer_id(0) , on_epollout_event(NULL), data(NULL) {} - ~EpollOutRequest() { + ~EpollOutRequest() override { // Remove the timer at last inside destructor to avoid // race with the place that registers the timer if (timer_id) { @@ -408,8 +434,8 @@ class Socket::EpollOutRequest : public SocketUser { } } - void BeforeRecycle(Socket*) { - // Recycle itself + void BeforeRecycle(Socket*) override { + // Recycle itself. delete this; } @@ -464,6 +490,7 @@ Socket::Socket(Forbidden) , _unwritten_bytes(0) , _epollout_butex(NULL) , _write_head(NULL) + , _is_wirte_shutdown(false) , _stream_set(NULL) , _total_streams_unconsumed_size(0) , _ninflight_app_health_check(0) @@ -485,7 +512,11 @@ void Socket::ReturnSuccessfulWriteRequest(Socket::WriteRequest* p) { const bthread_id_t id_wait = p->id_wait; butil::return_object(p); if (id_wait != INVALID_BTHREAD_ID) { - NotifyOnFailed(id_wait); + if (p->is_notify_on_success() && !Failed()) { + bthread_id_error(id_wait, 0); + } else { + NotifyOnFailed(id_wait); + } } } @@ -514,11 +545,18 @@ Socket::WriteRequest* Socket::ReleaseWriteRequestsExceptLast( } void Socket::ReleaseAllFailedWriteRequests(Socket::WriteRequest* req) { - CHECK(Failed()); - pthread_mutex_lock(&_id_wait_list_mutex); - const int error_code = non_zero_error_code(); - const std::string error_text = _error_text; - pthread_mutex_unlock(&_id_wait_list_mutex); + CHECK(Failed() || IsWriteShutdown()); + int error_code; + std::string error_text; + if (Failed()) { + pthread_mutex_lock(&_id_wait_list_mutex); + error_code = non_zero_error_code(); + error_text = _error_text; + pthread_mutex_unlock(&_id_wait_list_mutex); + } else { + error_code = ESHUTDOWNWRITE; + error_text = "Shutdown write of the socket"; + } // Notice that `req' is not tail if Address after IsWriteComplete fails. do { req = ReleaseWriteRequestsExceptLast(req, error_code, error_text); @@ -746,6 +784,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { m->_keepalive_options = options.keepalive_options; m->_bthread_tag = options.bthread_tag; CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed)); + m->_is_wirte_shutdown = false; // Must be last one! Internal fields of this Socket may be access // just after calling ResetFileDescriptor. if (m->ResetFileDescriptor(options.fd) != 0) { @@ -1385,7 +1424,7 @@ int Socket::ConnectIfNot(const timespec* abstime, WriteRequest* req) { // Have to hold a reference for `req' SocketUniquePtr s; ReAddress(&s); - req->socket = s.get(); + req->set_socket(s.get()); if (_conn) { if (_conn->Connect(this, abstime, KeepWriteIfConnected, req) < 0) { return -1; @@ -1457,7 +1496,7 @@ int Socket::HandleEpollOutRequest(int error_code, EpollOutRequest* req) { void Socket::AfterAppConnected(int err, void* data) { WriteRequest* req = static_cast(data); if (err == 0) { - Socket* const s = req->socket; + Socket* const s = req->get_socket(); SharedPart* sp = s->GetSharedPart(); if (sp) { sp->num_continuous_connect_timeouts.store(0, butil::memory_order_relaxed); @@ -1471,7 +1510,7 @@ void Socket::AfterAppConnected(int err, void* data) { KeepWrite(req); } } else { - SocketUniquePtr s(req->socket); + SocketUniquePtr s(req->get_socket()); if (err == ETIMEDOUT) { SharedPart* sp = s->GetOrNewSharedPart(); if (sp->num_continuous_connect_timeouts.fetch_add( @@ -1499,7 +1538,7 @@ static void* RunClosure(void* arg) { int Socket::KeepWriteIfConnected(int fd, int err, void* data) { WriteRequest* req = static_cast(data); - Socket* s = req->socket; + Socket* s = req->get_socket(); if (err == 0 && s->ssl_state() == SSL_CONNECTING) { // Run ssl connect in a new bthread to avoid blocking // the current bthread (thus blocking the EventDispatcher) @@ -1522,12 +1561,13 @@ int Socket::KeepWriteIfConnected(int fd, int err, void* data) { void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) { butil::fd_guard sockfd(fd); WriteRequest* req = static_cast(data); - Socket* s = req->socket; + Socket* s = req->get_socket(); CHECK_GE(sockfd, 0); if (err == 0 && s->CheckConnected(sockfd) == 0 && s->ResetFileDescriptor(sockfd) == 0) { if (s->_app_connect) { - s->_app_connect->StartConnect(req->socket, AfterAppConnected, req); + s->_app_connect->StartConnect(req->get_socket(), + AfterAppConnected, req); } else { // Successfully created a connection AfterAppConnected(0, req); @@ -1614,6 +1654,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) { // wait until it points to a valid WriteRequest or NULL. req->next = WriteRequest::UNCONNECTED; req->id_wait = opt.id_wait; + req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write); req->set_pipelined_count_and_user_message( opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags); return StartWrite(req, opt); @@ -1650,7 +1691,9 @@ int Socket::Write(SocketMessagePtr<>& msg, const WriteOptions* options_in) { // wait until it points to a valid WriteRequest or NULL. req->next = WriteRequest::UNCONNECTED; req->id_wait = opt.id_wait; - req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.auth_flags); + req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write); + req->set_pipelined_count_and_user_message( + opt.pipelined_count, msg.release(), opt.auth_flags); return StartWrite(req, opt); } @@ -1672,12 +1715,19 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) { bthread_t th; SocketUniquePtr ptr_for_keep_write; ssize_t nw = 0; + int ret = 0; // We've got the right to write. req->next = NULL; + + // Fast fail when write has been shutdown. + if (_is_wirte_shutdown) { + goto FAIL_TO_WRITE; + } + _is_wirte_shutdown = req->need_shutdown_write(); // Connect to remote_side() if not. - int ret = ConnectIfNot(opt.abstime, req); + ret = ConnectIfNot(opt.abstime, req); if (ret < 0) { saved_errno = errno; SetFailed(errno, "Fail to connect %s directly: %m", description().c_str()); @@ -1736,7 +1786,7 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) { KEEPWRITE_IN_BACKGROUND: ReAddress(&ptr_for_keep_write); - req->socket = ptr_for_keep_write.release(); + req->set_socket(ptr_for_keep_write.release()); if (bthread_start_background(&th, &BTHREAD_ATTR_NORMAL, KeepWrite, req) != 0) { LOG(FATAL) << "Fail to start KeepWrite"; @@ -1758,7 +1808,7 @@ static const size_t DATA_LIST_MAX = 256; void* Socket::KeepWrite(void* void_arg) { g_vars->nkeepwrite << 1; WriteRequest* req = static_cast(void_arg); - SocketUniquePtr s(req->socket); + SocketUniquePtr s(req->get_socket()); // When error occurs, spin until there's no more requests instead of // returning directly otherwise _write_head is permantly non-NULL which @@ -1766,11 +1816,18 @@ void* Socket::KeepWrite(void* void_arg) { WriteRequest* cur_tail = NULL; do { // req was written, skip it. + bool need_shutdown = false; if (req->next != NULL && req->data.empty()) { WriteRequest* const saved_req = req; + need_shutdown = req->need_shutdown_write(); req = req->next; s->ReturnSuccessfulWriteRequest(saved_req); } + if (need_shutdown) { + LOG(WARNING) << "Shutdown write of " << *s; + break; + } + const ssize_t nw = s->DoWrite(req); if (nw < 0) { if (errno != EAGAIN && errno != EOVERCROWDED) { @@ -1783,11 +1840,19 @@ void* Socket::KeepWrite(void* void_arg) { } else { s->AddOutputBytes(nw); } - // Release WriteRequest until non-empty data or last request. + // Release WriteRequest until non-empty data, last request or shutdown write. while (req->next != NULL && req->data.empty()) { WriteRequest* const saved_req = req; + need_shutdown = req->need_shutdown_write(); req = req->next; s->ReturnSuccessfulWriteRequest(saved_req); + if (need_shutdown) { + break; + } + } + if (need_shutdown) { + LOG(WARNING) << "Shutdown write of " << *s; + break; } // TODO(gejun): wait for epollout when we actually have written // all the data. This weird heuristic reduces 30us delay... @@ -1867,6 +1932,11 @@ ssize_t Socket::DoWrite(WriteRequest* req) { for (WriteRequest* p = req; p != NULL && ndata < DATA_LIST_MAX; p = p->next) { data_list[ndata++] = &p->data; + if (p->need_shutdown_write()) { + // Write WriteRequest until shutdown write. + _is_wirte_shutdown = true; + break; + } } if (ssl_state() == SSL_OFF) { @@ -2387,6 +2457,8 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) { os << "\n}"; } + os << "\nis_wirte_shutdown=" << ptr->_is_wirte_shutdown; + { int keepalive = 0; socklen_t len = sizeof(keepalive); diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 9d85aafaff..ecb8e3b4aa 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -166,6 +166,62 @@ struct PipelinedInfo { bthread_id_t id_wait; }; +// A data structure packed with a pointer and +// some extra information using a uint64 variable. +template +class PackedPtr { + static constexpr uint8_t MAX_POINTER_LEN = 48; + static constexpr uint64_t POINTER_MASK = ((uint64_t)1 << MAX_POINTER_LEN) - 1; + static constexpr uint64_t EXTRA_MASK = ~POINTER_MASK; +public: + PackedPtr() : _data(0) { + BAIDU_CASSERT(sizeof(PackedPtr) == 8, sizeof_packed_ptr_must_be_8); + } + + void set(T* ptr) { + // Clear the low 48 bits and then + // store the pointer in the low 48 bits. + _data = (_data & EXTRA_MASK) | + ((uint64_t)(uintptr_t)ptr & POINTER_MASK); + } + + void reset() { + // Clear the low 48 bits. + _data &= EXTRA_MASK; + } + + T* get() const { return (T*)(_data & POINTER_MASK); } + + void set_extra(uint16_t extra) { + // Clear the high 16 bits and then + // store the extra in the high 16 bits. + _data = (_data & POINTER_MASK) | + ((uint64_t)extra << MAX_POINTER_LEN); + } + + void reset_extra() { + // Clear the high 16 bits. + _data &= POINTER_MASK; + } + + uint16_t extra() const { return _data >> MAX_POINTER_LEN; } + + void set_ptr_and_extra(T* p, uint16_t extra) { + _data = ((uint64_t)(uintptr_t)p & POINTER_MASK) | + ((uint64_t)extra << MAX_POINTER_LEN); + } + + void reset_ptr_and_extra() { + _data = 0; + } + +private: + // Pointer is stored in the low 48 bits, + // extra information is stored in the high 16 bits. + uint64_t _data; +}; + + struct SocketSSLContext { SocketSSLContext(); ~SocketSSLContext(); @@ -269,11 +325,18 @@ friend class policy::H2GlobalStreamCreator; // - Write once when uncontended(most cases). // - Wait-free when contended. struct WriteOptions { - // `id_wait' is signalled when this Socket is SetFailed. To disable - // the signal, set this field to INVALID_BTHREAD_ID. - // `on_reset' of `id_wait' is NOT called when Write() returns non-zero. + // `id_wait' is signalled when this Socket is SetFailed or data is written + // successfully with `notify_on_success=true'. To disable the signal, set + // this field to INVALID_BTHREAD_ID. `on_reset' of `id_wait' is NOT called + // when Write() returns non-zero. // Default: INVALID_BTHREAD_ID bthread_id_t id_wait; + + // If this field is set to true and `id_wait' is not INVALID_BTHREAD_ID, + // `id_wait' can be signalled when write successfully. + // Default: false + bool notify_on_success; + // If no connection exists, a connection will be established to // remote_side() regarding deadline `abstime'. NULL means no timeout. // Default: NULL @@ -301,13 +364,27 @@ friend class policy::H2GlobalStreamCreator; // performance. Otherwise, each write only writes one `msg` into socket // and no KeepWrite thread can be created, which brings poor // performance. + // Default: false bool write_in_background; + // After this write complete, shutdown write of the socket. + // Default: false + bool shutdown_write; + WriteOptions() - : id_wait(INVALID_BTHREAD_ID), abstime(NULL) - , pipelined_count(0), auth_flags(0) - , ignore_eovercrowded(false), write_in_background(false) {} + : id_wait(INVALID_BTHREAD_ID) + , notify_on_success(false) + , abstime(NULL) + , pipelined_count(0) + , auth_flags(0) + , ignore_eovercrowded(false) + , write_in_background(false) + , shutdown_write(false) {} }; + + // True if write of socket is shutdown. + bool IsWriteShutdown() const { return _is_wirte_shutdown; } + int Write(butil::IOBuf *msg, const WriteOptions* options = NULL); // Write an user-defined message. `msg' is released when Write() is @@ -917,6 +994,8 @@ friend void DereferenceSocket(Socket*); // Storing data that are not flushed into `fd' yet. butil::atomic _write_head; + bool _is_wirte_shutdown; + butil::Mutex _stream_mutex; std::set *_stream_set; butil::atomic _total_streams_unconsumed_size; diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp index d225873531..f278c46b06 100644 --- a/test/brpc_socket_unittest.cpp +++ b/test/brpc_socket_unittest.cpp @@ -1225,7 +1225,6 @@ TEST_F(SocketTest, keepalive) { } } - TEST_F(SocketTest, keepalive_input_message) { int default_keepalive = 0; int default_keepalive_idle = 0; @@ -1418,3 +1417,217 @@ TEST_F(SocketTest, keepalive_input_message) { sockfd.release(); } } + +int HandleSocketSuccessWrite(bthread_id_t id, void* data, int error_code, + const std::string& error_text) { + auto success_count = static_cast(data); + EXPECT_NE(nullptr, success_count); + EXPECT_EQ(0, error_code); + ++(*success_count); + CHECK_EQ(0, bthread_id_unlock_and_destroy(id)); + return 0; +} + +TEST_F(SocketTest, notify_on_success) { + const size_t REP = 10000; + int fds[2]; + ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); + + brpc::SocketId id = 8888; + butil::EndPoint dummy; + ASSERT_EQ(0, str2endpoint("192.168.1.26:8080", &dummy)); + brpc::SocketOptions options; + options.fd = fds[1]; + options.remote_side = dummy; + options.user = new CheckRecycle; + ASSERT_EQ(0, brpc::Socket::Create(options, &id)); + brpc::SocketUniquePtr s; + ASSERT_EQ(0, brpc::Socket::Address(id, &s)); + s->_ssl_state = brpc::SSL_OFF; + ASSERT_EQ(2, brpc::NRefOfVRef(s->_versioned_ref)); + global_sock = s.get(); + ASSERT_TRUE(s.get()); + ASSERT_EQ(fds[1], s->fd()); + ASSERT_EQ(dummy, s->remote_side()); + ASSERT_EQ(id, s->id()); + + pthread_t rth; + ReaderArg reader_arg = { fds[0], 0 }; + pthread_create(&rth, NULL, reader, &reader_arg); + + size_t success_count = 0; + char buf[] = "hello reader side!"; + for (size_t c = 0; c < REP; ++c) { + bthread_id_t write_id; + ASSERT_EQ(0, bthread_id_create2(&write_id, &success_count, + HandleSocketSuccessWrite)); + brpc::Socket::WriteOptions wopt; + wopt.id_wait = write_id; + wopt.notify_on_success = true; + butil::IOBuf src; + src.append(buf, 16); + if (s->Write(&src, &wopt) != 0) { + if (errno == brpc::EOVERCROWDED) { + // The buf is full, sleep a while and retry. + bthread_usleep(1000); + --c; + continue; + } + PLOG(ERROR) << "Fail to write into SocketId=" << id; + break; + } + } + bthread_usleep(1000 * 1000); + + ASSERT_EQ(0, s->SetFailed()); + s.release()->Dereference(); + pthread_join(rth, NULL); + ASSERT_EQ(REP, success_count); + ASSERT_EQ((brpc::Socket*)NULL, global_sock); + close(fds[0]); +} + +struct ShutdownWriterArg { + size_t times; + brpc::SocketId socket_id; + butil::atomic total_count; + butil::atomic success_count; +}; + +int HandleSocketShutdownWrite(bthread_id_t id, void* data, int error_code, + const std::string& error_text) { + auto arg = static_cast(data); + EXPECT_NE(nullptr, arg); + EXPECT_TRUE(0 == error_code || brpc::ESHUTDOWNWRITE == error_code) << error_code; + ++arg->total_count; + if (0 == error_code) { + ++arg->success_count; + } + CHECK_EQ(0, bthread_id_unlock_and_destroy(id)); + return 0; +} + +void* ShutdownWriter(void* void_arg) { + auto arg = static_cast(void_arg); + brpc::SocketUniquePtr sock; + if (brpc::Socket::Address(arg->socket_id, &sock) < 0) { + LOG(INFO) << "Fail to address SocketId=" << arg->socket_id; + return NULL; + } + for (size_t c = 0; c < arg->times; ++c) { + bthread_id_t write_id; + EXPECT_EQ(0, bthread_id_create2(&write_id, arg, + HandleSocketShutdownWrite)); + brpc::Socket::WriteOptions wopt; + wopt.id_wait = write_id; + wopt.notify_on_success = true; + wopt.shutdown_write = true; + butil::IOBuf src; + src.push_back('a'); + if (sock->Write(&src, &wopt) != 0) { + if (errno == brpc::EOVERCROWDED) { + // The buf is full, sleep a while and retry. + bthread_usleep(1000); + --c; + continue; + } + } + } + return NULL; +} + +void TestShutdownWrite() { + const size_t REP = 100; + int fds[2]; + ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); + + brpc::SocketId id = 8888; + butil::EndPoint dummy; + ASSERT_EQ(0, str2endpoint("192.168.1.26:8080", &dummy)); + brpc::SocketOptions options; + options.fd = fds[1]; + options.remote_side = dummy; + options.user = new CheckRecycle; + ASSERT_EQ(0, brpc::Socket::Create(options, &id)); + brpc::SocketUniquePtr s; + ASSERT_EQ(0, brpc::Socket::Address(id, &s)); + s->_ssl_state = brpc::SSL_OFF; + ASSERT_EQ(2, brpc::NRefOfVRef(s->_versioned_ref)); + global_sock = s.get(); + ASSERT_TRUE(s.get()); + ASSERT_EQ(fds[1], s->fd()); + ASSERT_EQ(dummy, s->remote_side()); + ASSERT_EQ(id, s->id()); + ASSERT_FALSE(s->IsWriteShutdown()); + + pthread_t rth; + ReaderArg reader_arg = { fds[0], 0 }; + pthread_create(&rth, NULL, reader, &reader_arg); + + bthread_t th[3]; + ShutdownWriterArg args[ARRAY_SIZE(th)]; + for (size_t i = 0; i < ARRAY_SIZE(th); ++i) { + args[i].times = REP; + args[i].socket_id = id; + args[i].total_count = 0; + args[i].success_count = 0; + bthread_start_background(&th[i], NULL, ShutdownWriter, &args[i]); + } + + for (size_t i = 0; i < ARRAY_SIZE(th); ++i) { + ASSERT_EQ(0, bthread_join(th[i], NULL)); + } + bthread_usleep(50 * 1000); + + ASSERT_TRUE(s->IsWriteShutdown()); + ASSERT_FALSE(s->Failed()); + ASSERT_EQ(0, s->SetFailed()); + s.release()->Dereference(); + pthread_join(rth, NULL); + ASSERT_EQ((brpc::Socket*)NULL, global_sock); + close(fds[0]); + + size_t total_count = 0; + size_t success_count = 0; + for (auto & arg : args) { + total_count += arg.total_count; + success_count += arg.success_count; + } + ASSERT_EQ(REP * ARRAY_SIZE(th), total_count); + EXPECT_EQ((size_t)1, reader_arg.nread); + EXPECT_EQ((size_t)1, success_count); +} + +TEST_F(SocketTest, shutdown_write) { + for (int i = 0; i < 100; ++i) { + TestShutdownWrite(); + } +} + +TEST_F(SocketTest, packed_ptr) { + brpc::PackedPtr ptr; + ASSERT_EQ(nullptr, ptr.get()); + ASSERT_EQ(0, ptr.extra()); + + int a = 1; + uint16_t b = 2; + ptr.set(&a); + ASSERT_EQ(&a, ptr.get()); + *ptr.get() = b; + ASSERT_EQ(a, b); + ptr.set_extra(b); + ASSERT_EQ(b, ptr.extra()); + ptr.reset(); + ptr.reset_extra(); + ASSERT_EQ(nullptr, ptr.get()); + ASSERT_EQ(0, ptr.extra()); + + int c = 3; + uint16_t d = 4; + ptr.set_ptr_and_extra(&c, d); + ASSERT_EQ(&c, ptr.get()); + ASSERT_EQ(d, ptr.extra()); + ptr.reset_ptr_and_extra(); + ASSERT_EQ(nullptr, ptr.get()); + ASSERT_EQ(0, ptr.extra()); +}