Skip to content

Commit

Permalink
[core] Fixed bug: srt_accept failure may make accepted socket leak (#…
Browse files Browse the repository at this point in the history
…1884).

Added unit test.
  • Loading branch information
ethouris authored Jun 12, 2024
1 parent ebe2c71 commit 72303d7
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 18 deletions.
43 changes: 26 additions & 17 deletions srtcore/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ int srt::CUDTUnited::newConnection(const SRTSOCKET listen,
enterCS(ls->m_AcceptLock);
try
{
ls->m_QueuedSockets.insert(ns->m_SocketID);
ls->m_QueuedSockets[ns->m_SocketID] = ns->m_PeerAddr;
}
catch (...)
{
Expand Down Expand Up @@ -1110,8 +1110,22 @@ SRTSOCKET srt::CUDTUnited::accept(const SRTSOCKET listen, sockaddr* pw_addr, int
}
else if (ls->m_QueuedSockets.size() > 0)
{
set<SRTSOCKET>::iterator b = ls->m_QueuedSockets.begin();
u = *b;
map<SRTSOCKET, sockaddr_any>::iterator b = ls->m_QueuedSockets.begin();

if (pw_addr != NULL && pw_addrlen != NULL)
{
// Check if the length of the buffer to fill the name in
// was large enough.
const int len = b->second.size();
if (*pw_addrlen < len)
{
// In case when the address cannot be rewritten,
// DO NOT accept, but leave the socket in the queue.
throw CUDTException(MJ_NOTSUP, MN_INVAL, 0);
}
}

u = b->first;
ls->m_QueuedSockets.erase(b);
accepted = true;
}
Expand Down Expand Up @@ -1182,14 +1196,8 @@ SRTSOCKET srt::CUDTUnited::accept(const SRTSOCKET listen, sockaddr* pw_addr, int

if (pw_addr != NULL && pw_addrlen != NULL)
{
// Check if the length of the buffer to fill the name in
// was large enough.
const int len = s->m_PeerAddr.size();
if (*pw_addrlen < len)
throw CUDTException(MJ_NOTSUP, MN_INVAL, 0);

memcpy((pw_addr), &s->m_PeerAddr, len);
*pw_addrlen = len;
memcpy((pw_addr), s->m_PeerAddr.get(), s->m_PeerAddr.size());
*pw_addrlen = s->m_PeerAddr.size();
}

return u;
Expand Down Expand Up @@ -2751,23 +2759,24 @@ void srt::CUDTUnited::removeSocket(const SRTSOCKET u)

// if it is a listener, close all un-accepted sockets in its queue
// and remove them later
for (set<SRTSOCKET>::iterator q = s->m_QueuedSockets.begin(); q != s->m_QueuedSockets.end(); ++q)
for (map<SRTSOCKET, sockaddr_any>::iterator q = s->m_QueuedSockets.begin();
q != s->m_QueuedSockets.end(); ++ q)
{
sockets_t::iterator si = m_Sockets.find(*q);
sockets_t::iterator si = m_Sockets.find(q->first);
if (si == m_Sockets.end())
{
// gone in the meantime
LOGC(smlog.Error,
log << "removeSocket: IPE? socket @" << (*q) << " being queued for listener socket @"
<< s->m_SocketID << " is GONE in the meantime ???");
log << "removeSocket: IPE? socket @" << (q->first) << " being queued for listener socket @"
<< s->m_SocketID << " is GONE in the meantime ???");
continue;
}

CUDTSocket* as = si->second;

as->breakSocket_LOCKED();
m_ClosedSockets[*q] = as;
m_Sockets.erase(*q);
m_ClosedSockets[q->first] = as;
m_Sockets.erase(q->first);
}
}

Expand Down
2 changes: 1 addition & 1 deletion srtcore/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class CUDTSocket
CUDT m_UDT; //< internal SRT socket logic

public:
std::set<SRTSOCKET> m_QueuedSockets; //< set of connections waiting for accept()
std::map<SRTSOCKET, sockaddr_any> m_QueuedSockets; //< set of connections waiting for accept()

sync::Condition m_AcceptCond; //< used to block "accept" call
sync::Mutex m_AcceptLock; //< mutex associated to m_AcceptCond
Expand Down
67 changes: 67 additions & 0 deletions test/test_connection_timeout.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <chrono>
#include <thread>
#include <gtest/gtest.h>
#include "test_env.h"

Expand All @@ -12,6 +13,7 @@ typedef int SOCKET;

#include"platform_sys.h"
#include "srt.h"
#include "netinet_any.h"

using namespace std;

Expand Down Expand Up @@ -204,3 +206,68 @@ TEST_F(TestConnectionTimeout, BlockingLoop)
}


TEST(TestConnectionAPI, Accept)
{
using namespace std::chrono;
using namespace srt;

srt_startup();

const SRTSOCKET caller_sock = srt_create_socket();
const SRTSOCKET listener_sock = srt_create_socket();

const int eidl = srt_epoll_create();
const int eidc = srt_epoll_create();
const int ev_conn = SRT_EPOLL_OUT | SRT_EPOLL_ERR;
srt_epoll_add_usock(eidc, caller_sock, &ev_conn);
const int ev_acp = SRT_EPOLL_IN | SRT_EPOLL_ERR;
srt_epoll_add_usock(eidl, listener_sock, &ev_acp);

sockaddr_any sa = srt::CreateAddr("localhost", 5555, AF_INET);

ASSERT_NE(srt_bind(listener_sock, sa.get(), sa.size()), -1);
ASSERT_NE(srt_listen(listener_sock, 1), -1);

// Set non-blocking mode so that you can wait for readiness
bool no = false;
srt_setsockflag(caller_sock, SRTO_RCVSYN, &no, sizeof no);
srt_setsockflag(listener_sock, SRTO_RCVSYN, &no, sizeof no);

srt_connect(caller_sock, sa.get(), sa.size());

SRT_EPOLL_EVENT ready[2];
int nready = srt_epoll_uwait(eidl, ready, 2, 1000); // Wait 1s
EXPECT_EQ(nready, 1);
EXPECT_EQ(ready[0].fd, listener_sock);
// EXPECT_EQ(ready[0].events, SRT_EPOLL_IN);

// Now call the accept function incorrectly
int size = 0;
sockaddr_storage saf;

EXPECT_EQ(srt_accept(listener_sock, (sockaddr*)&saf, &size), SRT_ERROR);

std::this_thread::sleep_for(seconds(1));

// Set correctly
size = sizeof (sockaddr_in6);
EXPECT_NE(srt_accept(listener_sock, (sockaddr*)&saf, &size), SRT_ERROR);

// Ended up with error, but now you should also expect error on the caller side.

// Wait 5s until you get a connection broken.
nready = srt_epoll_uwait(eidc, ready, 2, 5000);
EXPECT_EQ(nready, 1);
if (nready == 1)
{
// Do extra checks only if you know that this was returned.
EXPECT_EQ(ready[0].fd, caller_sock);
EXPECT_EQ(ready[0].events & SRT_EPOLL_ERR, 0);
}
srt_close(caller_sock);
srt_close(listener_sock);

srt_cleanup();
}


0 comments on commit 72303d7

Please sign in to comment.