Skip to content

Commit

Permalink
core: Correctly close sockets (#2357)
Browse files Browse the repository at this point in the history
* core: Correctly close sockets

* Update socket_holder.cpp - Fix newline style

* Update socket_holder.cpp

* Update src/mavsdk/core/tcp_client_connection.cpp

Co-authored-by: Jonas Vautherin <accounts@jonas.vautherin.ch>

* Update socket_holder.cpp - fix code style

* Update tcp_client_connection.cpp - fix code style

* Update tcp_server_connection.cpp - fix code style

* Update socket_holder.h - fix code style

* Update src/mavsdk/core/socket_holder.h

Co-authored-by: Julian Oes <julian@oes.ch>

* Update socket_holder.h - use 64 bit descriptor type on Win64

* Update socket_holder.cpp - use 64 bit descriptor type on Win64

* Update socket_holder.cpp - minor improving of if logic

* Remove default move constructor

It is currently not in use, and the default implementation is not suitable because it does not change _fd to INVALID for the object being copied from

---------

Co-authored-by: Jonas Vautherin <accounts@jonas.vautherin.ch>
Co-authored-by: Julian Oes <julian@oes.ch>
  • Loading branch information
3 people authored Jul 25, 2024
1 parent da070c5 commit da13d0a
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 79 deletions.
1 change: 1 addition & 0 deletions src/mavsdk/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ target_sources(mavsdk
server_component.cpp
server_component_impl.cpp
server_plugin_impl_base.cpp
socket_holder.cpp
tcp_client_connection.cpp
tcp_server_connection.cpp
timeout_handler.cpp
Expand Down
53 changes: 53 additions & 0 deletions src/mavsdk/core/socket_holder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "socket_holder.h"

#ifndef WINDOWS
#include <sys/socket.h>
#include <unistd.h>
#endif

namespace mavsdk {

SocketHolder::SocketHolder(DescriptorType fd) noexcept : _fd{fd} {}

SocketHolder::~SocketHolder() noexcept
{
close();
}

void SocketHolder::reset(DescriptorType fd) noexcept
{
if (_fd != fd) {
close();
_fd = fd;
}
}

void SocketHolder::close() noexcept
{
if (!empty()) {
#if defined(WINDOWS)
shutdown(_fd, SD_BOTH);
closesocket(_fd);
WSACleanup();
#else
// This should interrupt a recv/recvfrom call.
shutdown(_fd, SHUT_RDWR);

// But on Mac, closing is also needed to stop blocking recv/recvfrom.
::close(_fd);
#endif
_fd = invalid_socket_fd;
}
}

bool SocketHolder::empty() const noexcept
{
return _fd == invalid_socket_fd;
}

SocketHolder::DescriptorType SocketHolder::get() const noexcept
{
return _fd;
}

} // namespace mavsdk
37 changes: 37 additions & 0 deletions src/mavsdk/core/socket_holder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#if defined(WINDOWS)
#include <winsock2.h>
#endif

namespace mavsdk {

class SocketHolder {
public:
#if defined(WINDOWS)
using DescriptorType = SOCKET;
static constexpr DescriptorType invalid_socket_fd = INVALID_SOCKET;
#else
using DescriptorType = int;
static constexpr DescriptorType invalid_socket_fd = -1;
#endif

SocketHolder() noexcept = default;
explicit SocketHolder(DescriptorType socket_fd) noexcept;

~SocketHolder() noexcept;

void reset(DescriptorType fd) noexcept;
void close() noexcept;

bool empty() const noexcept;
DescriptorType get() const noexcept;

private:
SocketHolder(const SocketHolder&) = delete;
SocketHolder& operator=(const SocketHolder&) = delete;

DescriptorType _fd = invalid_socket_fd;
};

} // namespace mavsdk
31 changes: 10 additions & 21 deletions src/mavsdk/core/tcp_client_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <unistd.h> // for close()
#endif

#ifndef WINDOWS
Expand Down Expand Up @@ -71,9 +70,9 @@ ConnectionResult TcpClientConnection::setup_port()
}
#endif

_socket_fd = socket(AF_INET, SOCK_STREAM, 0);
_socket_fd.reset(socket(AF_INET, SOCK_STREAM, 0));

if (_socket_fd < 0) {
if (_socket_fd.empty()) {
LogErr() << "socket error" << GET_ERROR(errno);
_is_ok = false;
return ConnectionResult::SocketError;
Expand All @@ -93,8 +92,10 @@ ConnectionResult TcpClientConnection::setup_port()

memcpy(&remote_addr.sin_addr, hp->h_addr, hp->h_length);

if (connect(_socket_fd, reinterpret_cast<sockaddr*>(&remote_addr), sizeof(struct sockaddr_in)) <
0) {
if (connect(
_socket_fd.get(),
reinterpret_cast<sockaddr*>(&remote_addr),
sizeof(struct sockaddr_in)) < 0) {
LogErr() << "connect error: " << GET_ERROR(errno);
_is_ok = false;
return ConnectionResult::SocketConnectionError;
Expand All @@ -113,19 +114,7 @@ ConnectionResult TcpClientConnection::stop()
{
_should_exit = true;

#ifndef WINDOWS
// This should interrupt a recv/recvfrom call.
shutdown(_socket_fd, SHUT_RDWR);

// But on Mac, closing is also needed to stop blocking recv/recvfrom.
close(_socket_fd);
#else
shutdown(_socket_fd, SD_BOTH);

closesocket(_socket_fd);

WSACleanup();
#endif
_socket_fd.close();

if (_recv_thread) {
_recv_thread->join();
Expand Down Expand Up @@ -175,7 +164,7 @@ bool TcpClientConnection::send_message(const mavlink_message_t& message)
#endif

const auto send_len = sendto(
_socket_fd,
_socket_fd.get(),
reinterpret_cast<char*>(buffer),
buffer_len,
flags,
Expand All @@ -202,7 +191,7 @@ void TcpClientConnection::receive()
setup_port();
}

const auto recv_len = recv(_socket_fd, buffer, sizeof(buffer), 0);
const auto recv_len = recv(_socket_fd.get(), buffer, sizeof(buffer), 0);

if (recv_len == 0) {
// This can happen when shutdown is called on the socket,
Expand All @@ -212,7 +201,7 @@ void TcpClientConnection::receive()
}

if (recv_len < 0) {
// This happens on desctruction when close(_socket_fd) is called,
// This happens on destruction when close(_socket_fd.get()) is called,
// therefore be quiet.
// LogErr() << "recvfrom error: " << GET_ERROR(errno);
// Something went wrong, we should try to re-connect in next iteration.
Expand Down
4 changes: 3 additions & 1 deletion src/mavsdk/core/tcp_client_connection.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "socket_holder.h"

#include <atomic>
#include <mutex>
#include <memory>
Expand Down Expand Up @@ -43,7 +45,7 @@ class TcpClientConnection : public Connection {
int _remote_port_number;

std::mutex _mutex = {};
int _socket_fd = -1;
SocketHolder _socket_fd;

std::unique_ptr<std::thread> _recv_thread{};
std::atomic_bool _should_exit;
Expand Down
50 changes: 23 additions & 27 deletions src/mavsdk/core/tcp_server_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <unistd.h> // for close()
#endif

#ifndef WINDOWS
Expand Down Expand Up @@ -57,8 +56,8 @@ ConnectionResult TcpServerConnection::start()
}
#endif

_server_socket_fd = socket(AF_INET, SOCK_STREAM, 0);
if (_server_socket_fd < 0) {
_server_socket_fd.reset(socket(AF_INET, SOCK_STREAM, 0));
if (_server_socket_fd.empty()) {
LogErr() << "socket error: " << GET_ERROR(errno);
return ConnectionResult::SocketError;
}
Expand All @@ -68,13 +67,15 @@ ConnectionResult TcpServerConnection::start()
server_addr.sin_addr.s_addr = INADDR_ANY;
server_addr.sin_port = htons(_local_port);

if (bind(_server_socket_fd, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr)) <
0) {
if (bind(
_server_socket_fd.get(),
reinterpret_cast<sockaddr*>(&server_addr),
sizeof(server_addr)) < 0) {
LogErr() << "bind error: " << GET_ERROR(errno);
return ConnectionResult::SocketError;
}

if (listen(_server_socket_fd, 3) < 0) {
if (listen(_server_socket_fd.get(), 3) < 0) {
LogErr() << "listen error: " << GET_ERROR(errno);
return ConnectionResult::SocketError;
}
Expand All @@ -89,16 +90,8 @@ ConnectionResult TcpServerConnection::stop()
{
_should_exit = true;

#ifndef WINDOWS
shutdown(_client_socket_fd, SHUT_RDWR);
close(_client_socket_fd);
close(_server_socket_fd);
#else
shutdown(_client_socket_fd, SD_BOTH);
closesocket(_client_socket_fd);
closesocket(_server_socket_fd);
WSACleanup();
#endif
_client_socket_fd.close();
_server_socket_fd.close();

if (_accept_receive_thread && _accept_receive_thread->joinable()) {
_accept_receive_thread->join();
Expand Down Expand Up @@ -126,7 +119,7 @@ bool TcpServerConnection::send_message(const mavlink_message_t& message)
#endif

const auto send_len =
send(_client_socket_fd, reinterpret_cast<const char*>(buffer), buffer_len, flags);
send(_client_socket_fd.get(), reinterpret_cast<const char*>(buffer), buffer_len, flags);

if (send_len != buffer_len) {
LogErr() << "send failure: " << GET_ERROR(errno);
Expand All @@ -140,27 +133,28 @@ void TcpServerConnection::accept_client()
#ifdef WINDOWS
// Set server socket to non-blocking
u_long iMode = 1;
int iResult = ioctlsocket(_server_socket_fd, FIONBIO, &iMode);
int iResult = ioctlsocket(_server_socket_fd.get(), FIONBIO, &iMode);
if (iResult != 0) {
LogErr() << "ioctlsocket failed with error: " << WSAGetLastError();
}
#else
// Set server socket to non-blocking
int flags = fcntl(_server_socket_fd, F_GETFL, 0);
fcntl(_server_socket_fd, F_SETFL, flags | O_NONBLOCK);
int flags = fcntl(_server_socket_fd.get(), F_GETFL, 0);
fcntl(_server_socket_fd.get(), F_SETFL, flags | O_NONBLOCK);
#endif

while (!_should_exit) {
fd_set readfds;
FD_ZERO(&readfds);
FD_SET(_server_socket_fd, &readfds);
FD_SET(_server_socket_fd.get(), &readfds);

// Set timeout to 1 second
timeval timeout;
timeout.tv_sec = 1;
timeout.tv_usec = 0;

const int activity = select(_server_socket_fd + 1, &readfds, nullptr, nullptr, &timeout);
const int activity =
select(_server_socket_fd.get() + 1, &readfds, nullptr, nullptr, &timeout);

if (activity < 0 && errno != EINTR) {
LogErr() << "select error: " << GET_ERROR(errno);
Expand All @@ -172,13 +166,15 @@ void TcpServerConnection::accept_client()
continue;
}

if (FD_ISSET(_server_socket_fd, &readfds)) {
if (FD_ISSET(_server_socket_fd.get(), &readfds)) {
sockaddr_in client_addr{};
socklen_t client_addr_len = sizeof(client_addr);

_client_socket_fd = accept(
_server_socket_fd, reinterpret_cast<sockaddr*>(&client_addr), &client_addr_len);
if (_client_socket_fd < 0) {
_client_socket_fd.reset(accept(
_server_socket_fd.get(),
reinterpret_cast<sockaddr*>(&client_addr),
&client_addr_len));
if (_client_socket_fd.empty()) {
if (_should_exit) {
return;
}
Expand All @@ -197,7 +193,7 @@ void TcpServerConnection::receive()

bool dataReceived = false;
while (!dataReceived && !_should_exit) {
const auto recv_len = recv(_client_socket_fd, buffer.data(), buffer.size(), 0);
const auto recv_len = recv(_client_socket_fd.get(), buffer.data(), buffer.size(), 0);

#ifdef WINDOWS
if (recv_len == SOCKET_ERROR) {
Expand Down
5 changes: 3 additions & 2 deletions src/mavsdk/core/tcp_server_connection.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "connection.h"
#include "socket_holder.h"

#include <atomic>
#include <string>
Expand Down Expand Up @@ -28,8 +29,8 @@ class TcpServerConnection : public Connection {
Connection::ReceiverCallback _receiver_callback;
std::string _local_ip;
int _local_port;
int _server_socket_fd{-1};
int _client_socket_fd{-1};
SocketHolder _server_socket_fd;
SocketHolder _client_socket_fd;
std::unique_ptr<std::thread> _accept_receive_thread;
std::atomic<bool> _should_exit{false};
};
Expand Down
Loading

0 comments on commit da13d0a

Please sign in to comment.