diff --git a/src/core/client/client.h b/src/core/client/client.h index f0d08330..3bc0c0fc 100644 --- a/src/core/client/client.h +++ b/src/core/client/client.h @@ -53,8 +53,7 @@ class SSFClient boost::asio::io_service& get_io_service(); private: - void NetworkToTransport(const boost::system::error_code& ec, - NetworkSocketPtr p_socket); + void NetworkToTransport(const boost::system::error_code& ec); void DoSSFStart(NetworkSocketPtr p_socket, const boost::system::error_code& ec); @@ -67,6 +66,8 @@ class SSFClient BaseUserServicePtr p_user_service, boost::system::error_code ec); private: + NetworkSocketPtr p_socket_; + AsyncEngine async_engine_; Demux fiber_demux_; diff --git a/src/core/client/client.ipp b/src/core/client/client.ipp index ae427c55..f6fcf1e2 100644 --- a/src/core/client/client.ipp +++ b/src/core/client/client.ipp @@ -30,6 +30,7 @@ SSFClient::SSFClient(std::vector user_services, ClientCallback callback) : T( boost::bind(&SSFClient::DoSSFStart, this, _1, _2)), + p_socket_(nullptr), async_engine_(), fiber_demux_(async_engine_.get_io_service()), user_services_(user_services), @@ -51,8 +52,7 @@ void SSFClient::Run(const NetworkQuery& query, } // Create network socket - NetworkSocketPtr p_socket = - std::make_shared(async_engine_.get_io_service()); + p_socket_ = std::make_shared(async_engine_.get_io_service()); // resolve remote endpoint with query NetworkResolver resolver(async_engine_.get_io_service()); @@ -66,13 +66,23 @@ void SSFClient::Run(const NetworkQuery& query, async_engine_.Start(); // async connect client to given endpoint - p_socket->async_connect( + p_socket_->async_connect( *endpoint_it, - boost::bind(&SSFClient::NetworkToTransport, this, _1, p_socket)); + boost::bind(&SSFClient::NetworkToTransport, this, _1)); } template class T> void SSFClient::Stop() { + if (!async_engine_.IsStarted()) { + return; + } + + if (p_socket_.get() != nullptr) { + boost::system::error_code close_ec; + p_socket_->close(close_ec); + p_socket_.reset(); + } + fiber_demux_.close(); async_engine_.Stop(); @@ -84,19 +94,18 @@ boost::asio::io_service& SSFClient::get_io_service() { } template class T> -void SSFClient::NetworkToTransport(const boost::system::error_code& ec, - NetworkSocketPtr p_socket) { +void SSFClient::NetworkToTransport(const boost::system::error_code& ec) { if (!ec) { - this->DoSSFInitiate(p_socket); + this->DoSSFInitiate(std::move(p_socket_)); return; } SSF_LOG(kLogError) << "client: error when connecting to server: " << ec.message(); - if (p_socket) { + if (p_socket_) { boost::system::error_code close_ec; - p_socket->close(close_ec); + p_socket_->close(close_ec); } Notify(ssf::services::initialisation::NETWORK, nullptr, ec); diff --git a/src/tests/network/CMakeLists.txt b/src/tests/network/CMakeLists.txt index 9f550c24..8e91a0fb 100644 --- a/src/tests/network/CMakeLists.txt +++ b/src/tests/network/CMakeLists.txt @@ -80,6 +80,27 @@ add_target("ssf_client_server_cipher_suites_tests" ) project_group(${network_tests_group_name} ssf_client_server_cipher_suites_tests) +# --- SSF Client tests +set(ssf_client_source_files + "ssf_client_tests.cpp" + ${SSF_SOURCES} +) + +add_target("ssf_client_tests" + TYPE + executable ${EXEC_FLAG} TEST + LINKS + ${OpenSSL_LIBRARIES} + ${Boost_LIBRARIES} + ${PLATFORM_SPECIFIC_LIB_DEP} + lib_ssf_network + PREFIX_SKIP .*/src + HEADER_FILTER "\\.h(h|m|pp|xx|\\+\\+)?" + FILES + ${ssf_client_source_files} +) +project_group(${network_tests_group_name} ssf_client_tests) + # --- SSF Server tests set(ssf_server_source_files "ssf_server_tests.cpp" diff --git a/src/tests/network/ssf_client_tests.cpp b/src/tests/network/ssf_client_tests.cpp new file mode 100644 index 00000000..d89b3547 --- /dev/null +++ b/src/tests/network/ssf_client_tests.cpp @@ -0,0 +1,117 @@ +#include +#include +#include + +#include +#include +#include + +#include + +#include "common/config/config.h" + +#include "core/network_protocol.h" +#include "core/transport_virtual_layer_policies/transport_protocol_policy.h" +#include "core/client/client.h" + +using NetworkProtocol = ssf::network::NetworkProtocol; +using Client = + ssf::SSFClient; +using Demux = Client::Demux; +using BaseUserServicePtr = + ssf::services::BaseUserService::BaseUserServicePtr; + +void InitTCPServer(boost::asio::ip::tcp::acceptor& server, int server_port); + +TEST(SSFClientTest, CloseWhileConnecting) { + ssf::AsyncEngine async_engine; + + std::condition_variable wait_stop_cv; + std::mutex mutex; + bool stopped = false; + + // End test callback + auto end_test = [&wait_stop_cv, &mutex, &stopped](bool status) { + { + boost::lock_guard lock(mutex); + stopped = status; + } + wait_stop_cv.notify_all(); + }; + + // Init server + int server_port = 15000; + boost::asio::ip::tcp::acceptor server(async_engine.get_io_service()); + InitTCPServer(server, 15000); + + // Init timer (if client hangs) + boost::system::error_code timer_ec; + boost::asio::steady_timer timer(async_engine.get_io_service()); + timer.expires_from_now(std::chrono::seconds(5), timer_ec); + ASSERT_EQ(0, timer_ec.value()); + timer.async_wait([&end_test](const boost::system::error_code& ec) { + EXPECT_NE(0, ec.value()) << "Timer should be canceled. Client is hanging"; + if (!ec) { + end_test(false); + } + }); + + // Init client + std::vector client_options; + ssf::config::Config ssf_config; + ssf_config.Init(); + + auto endpoint_query = NetworkProtocol::GenerateClientQuery( + "127.0.0.1", std::to_string(server_port), ssf_config, {}); + + auto callback = [&end_test](ssf::services::initialisation::type type, + BaseUserServicePtr p_user_service, + const boost::system::error_code& ec) { + EXPECT_EQ(ssf::services::initialisation::NETWORK, type); + EXPECT_NE(0, ec.value()); + end_test(ec.value() != 0); + }; + Client client(client_options, ssf_config.services(), callback); + + boost::system::error_code run_ec; + + // Connect client to server + client.Run(endpoint_query, run_ec); + ASSERT_EQ(0, run_ec.value()); + + // Wait new server connection + async_engine.Start(); + ASSERT_TRUE(async_engine.IsStarted()); + boost::asio::ip::tcp::socket socket(async_engine.get_io_service()); + server.async_accept(socket, [&client](const boost::system::error_code& ec) { + EXPECT_EQ(0, ec.value()) << "Accept connection in error"; + // Stop client while connecting + client.Stop(); + }); + + // Wait client action + std::unique_lock lock(mutex); + wait_stop_cv.wait(lock); + lock.unlock(); + + EXPECT_TRUE(stopped) << "Stop failed"; + + timer.cancel(timer_ec); + boost::system::error_code close_ec; + socket.close(close_ec); + server.close(close_ec); + + async_engine.Stop(); +} + +void InitTCPServer(boost::asio::ip::tcp::acceptor& server, int server_port) { + boost::system::error_code server_ec; + boost::asio::ip::tcp::endpoint server_ep(boost::asio::ip::tcp::v4(), + server_port); + server.open(boost::asio::ip::tcp::v4(), server_ec); + ASSERT_EQ(0, server_ec.value()) << "Could not open server acceptor"; + server.bind(server_ep, server_ec); + ASSERT_EQ(0, server_ec.value()) << "Could not bind server acceptor"; + server.listen(boost::asio::socket_base::max_connections, server_ec); + ASSERT_EQ(0, server_ec.value()) << "Server acceptor could not listen"; +}