Skip to content

Commit

Permalink
Support connect on socket create
Browse files Browse the repository at this point in the history
  • Loading branch information
chenBright committed Mar 17, 2024
1 parent 85b664b commit ff76e27
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 25 deletions.
54 changes: 36 additions & 18 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,12 +746,25 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
m->_keepalive_options = options.keepalive_options;
m->_bthread_tag = options.bthread_tag;
CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed));
int fd = options.fd;
if (!m->ValidFileDescriptor(fd) && options.connect_on_create) {
// Connect on created.
fd = m->DoConnect(options.abstime, NULL, NULL);
if (fd < 0) {
PLOG(ERROR) << "Fail to connect to " << options.remote_side;
int error_code = errno != 0 ? errno : EHOSTDOWN;
m->SetFailed(error_code, "Fail to connect to %s: %s",
butil::endpoint2str(options.remote_side).c_str(),
berror(error_code));
return -1;
}
}
// Must be last one! Internal fields of this Socket may be access
// just after calling ResetFileDescriptor.
if (m->ResetFileDescriptor(options.fd) != 0) {
if (m->ResetFileDescriptor(fd) != 0) {
const int saved_errno = errno;
PLOG(ERROR) << "Fail to ResetFileDescriptor";
m->SetFailed(saved_errno, "Fail to ResetFileDescriptor: %s",
m->SetFailed(saved_errno, "Fail to ResetFileDescriptor: %s",
berror(saved_errno));
return -1;
}
Expand Down Expand Up @@ -1363,37 +1376,42 @@ int Socket::CheckConnected(int sockfd) {
return -1;
}

butil::EndPoint local_point;
CHECK_EQ(0, butil::get_local_side(sockfd, &local_point));
LOG_IF(INFO, FLAGS_log_connected)
<< "Connected to " << remote_side()
<< " via fd=" << (int)sockfd << " SocketId=" << id()
<< " local_side=" << local_point;
if (FLAGS_log_connected) {
butil::EndPoint local_point;
CHECK_EQ(0, butil::get_local_side(sockfd, &local_point));
LOG(INFO) << "Connected to " << remote_side()
<< " via fd=" << (int)sockfd << " SocketId=" << id()
<< " local_side=" << local_point;
}

if (CreatedByConnect()) {
g_vars->channel_conn << 1;
}
// Doing SSL handshake after TCP connected
return SSLHandshake(sockfd, false);
}

int Socket::DoConnect(const timespec* abstime,
int (*on_connect)(int, int, void*), void* data) {
if (_conn) {
return _conn->Connect(this, abstime, on_connect, data);
} else {
return Connect(abstime, on_connect, data);
}
}

int Socket::ConnectIfNot(const timespec* abstime, WriteRequest* req) {
if (_fd.load(butil::memory_order_consume) >= 0) {
return 0;
}
// Set tag for client side socket
// Set tag for client side socket.
_bthread_tag = bthread_self_tag();
// Have to hold a reference for `req'
// Have to hold a reference for `req'.
SocketUniquePtr s;
ReAddress(&s);
req->socket = s.get();
if (_conn) {
if (_conn->Connect(this, abstime, KeepWriteIfConnected, req) < 0) {
return -1;
}
} else {
if (Connect(abstime, KeepWriteIfConnected, req) < 0) {
return -1;
}
if (DoConnect(abstime, KeepWriteIfConnected, req) < 0) {
return -1;
}
s.release();
return 1;
Expand Down
11 changes: 11 additions & 0 deletions src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ struct SocketOptions {
// user->BeforeRecycle() before recycling.
int fd;
butil::EndPoint remote_side;
// If `connect_on_created' is true and `fd' is less than 0,
// a client connection will be established to remote_side()
// regarding deadline `abstime' when Socket is being created.
// Default: false, means that a connection will be established
// on first write.
bool connect_on_create;
// Default: NULL, means no timeout.
const timespec* abstime;
SocketUser* user;
// When *edge-triggered* events happen on the file descriptor, callback
// `on_edge_triggered_events' will be called. Inside the callback, user
Expand Down Expand Up @@ -640,8 +648,11 @@ friend void DereferenceSocket(Socket*);
// starting a connection request and `on_connect' will be called
// when connecting completes (whether it succeeds or not)
// Returns the socket fd on success, -1 otherwise
int DoConnect(const timespec* abstime,
int (*on_connect)(int fd, int err, void* data), void* data);
int Connect(const timespec* abstime,
int (*on_connect)(int fd, int err, void* data), void* data);

int CheckConnected(int sockfd);

// [Not thread-safe] Only used by `Write'.
Expand Down
2 changes: 2 additions & 0 deletions src/brpc/socket_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ BUTIL_FORCE_INLINE uint64_t MakeVRef(uint32_t version, int32_t nref) {

inline SocketOptions::SocketOptions()
: fd(-1)
, connect_on_create(false)
, abstime(NULL)
, user(NULL)
, on_edge_triggered_events(NULL)
, health_check_interval_s(-1)
Expand Down
102 changes: 95 additions & 7 deletions test/brpc_ssl_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include <butil/macros.h>
#include <butil/fd_guard.h>
#include <butil/files/scoped_file.h>
#include <brpc/policy/baidu_rpc_meta.pb.h>
#include <brpc/policy/baidu_rpc_protocol.h>
#include <brpc/policy/most_common_message.h>
#include "brpc/global.h"
#include "brpc/socket.h"
#include "brpc/server.h"
Expand Down Expand Up @@ -54,11 +57,11 @@ const std::string EXP_RESPONSE = "world";
class EchoServiceImpl : public test::EchoService {
public:
EchoServiceImpl() : count(0) {}
virtual ~EchoServiceImpl() { g_delete = true; }
virtual void Echo(google::protobuf::RpcController* cntl_base,
const test::EchoRequest* request,
test::EchoResponse* response,
google::protobuf::Closure* done) {
~EchoServiceImpl() override { g_delete = true; }
void Echo(google::protobuf::RpcController* cntl_base,
const test::EchoRequest* request,
test::EchoResponse* response,
google::protobuf::Closure* done) override {
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = (brpc::Controller*)cntl_base;
count.fetch_add(1, butil::memory_order_relaxed);
Expand Down Expand Up @@ -207,7 +210,7 @@ TEST_F(SSLTest, force_ssl) {
test::EchoService_Stub stub(&channel);
test::EchoResponse res;
stub.Echo(&cntl, &req, &res, NULL);
EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
ASSERT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
}

{
Expand All @@ -218,13 +221,98 @@ TEST_F(SSLTest, force_ssl) {
test::EchoService_Stub stub(&channel);
test::EchoResponse res;
stub.Echo(&cntl, &req, &res, NULL);
EXPECT_TRUE(cntl.Failed());
ASSERT_TRUE(cntl.Failed());
}

ASSERT_EQ(0, server.Stop(0));
ASSERT_EQ(0, server.Join());
}

void ProcessResponse(brpc::InputMessageBase* msg_base) {
brpc::DestroyingPtr<brpc::policy::MostCommonMessage> msg(
static_cast<brpc::policy::MostCommonMessage*>(msg_base));
brpc::policy::RpcMeta meta;
ASSERT_TRUE(brpc::ParsePbFromIOBuf(&meta, msg->meta));
const brpc::policy::RpcResponseMeta &response_meta = meta.response();
ASSERT_EQ(0, response_meta.error_code()) << response_meta.error_text();

const brpc::CallId cid = { static_cast<uint64_t>(meta.correlation_id()) };
brpc::Controller* cntl = NULL;
ASSERT_EQ(0, bthread_id_lock(cid, (void**)&cntl));
ASSERT_NE(nullptr, cntl);
ASSERT_TRUE(brpc::ParsePbFromIOBuf(cntl->response(), msg->payload));
ASSERT_EQ(0, bthread_id_unlock_and_destroy(cid));
}

TEST_F(SSLTest, fd_ssl) {
brpc::Protocol dummy_protocol = {
brpc::policy::ParseRpcMessage, brpc::SerializeRequestDefault,
brpc::policy::PackRpcRequest,NULL, ProcessResponse,
NULL, NULL, NULL, brpc::CONNECTION_TYPE_ALL, "ssl_ut_baidu"
};
ASSERT_EQ(0, RegisterProtocol((brpc::ProtocolType)30, dummy_protocol));

brpc::InputMessageHandler dummy_handler ={
dummy_protocol.parse, dummy_protocol.process_response,
NULL, NULL, dummy_protocol.name
};
brpc::InputMessenger messenger;
ASSERT_EQ(0, messenger.AddHandler(dummy_handler));

const int port = 8613;
brpc::Server server;
brpc::ServerOptions server_options;
server_options.force_ssl = true;

brpc::CertInfo cert;
cert.certificate = "cert1.crt";
cert.private_key = "cert1.key";
server_options.mutable_ssl_options()->default_cert = cert;

EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService(
&echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE));
ASSERT_EQ(0, server.Start(port, &server_options));

// Create client socket.
brpc::SocketOptions socket_options;
butil::EndPoint ep(butil::IP_ANY, port);
socket_options.remote_side = ep;
socket_options.connect_on_create = true;
socket_options.on_edge_triggered_events = brpc::InputMessenger::OnNewMessages;
socket_options.user = &messenger;
brpc::ChannelSSLOptions ssl_options;
SSL_CTX* raw_ctx = brpc::CreateClientSSLContext(ssl_options);
ASSERT_NE(nullptr, raw_ctx);
std::shared_ptr<brpc::SocketSSLContext> ssl_ctx
= std::make_shared<brpc::SocketSSLContext>();
ssl_ctx->raw_ctx = raw_ctx;
socket_options.initial_ssl_ctx = ssl_ctx;

brpc::SocketId socket_id;
ASSERT_EQ(0, brpc::Socket::Create(socket_options, &socket_id));
brpc::SocketUniquePtr ptr;
ASSERT_EQ(0, brpc::Socket::Address(socket_id, &ptr));

test::EchoRequest req;
req.set_message(EXP_REQUEST);
for (int i = 0; i < 100; ++i) {
test::EchoResponse res;
butil::IOBuf request_buf;
butil::IOBuf request_body;
brpc::Controller cntl;
cntl._response = &res;
const brpc::CallId correlation_id = cntl.call_id();
brpc::SerializeRequestDefault(&request_body, &cntl, &req);
brpc::policy::PackRpcRequest(&request_buf, NULL, correlation_id.value,
test::EchoService_Stub::descriptor()->method(0),
&cntl, request_body, NULL);
ASSERT_EQ(0, ptr->Write(&request_buf));
brpc::Join(correlation_id);
ASSERT_EQ(EXP_RESPONSE, res.message());
}
}

void CheckCert(const char* cname, const char* cert) {
const int port = 8613;
brpc::Channel channel;
Expand Down

0 comments on commit ff76e27

Please sign in to comment.