Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flush socket on body limit #233

Merged
merged 4 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/helper/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <chrono>
#include <mutex>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include <thread>

Expand Down Expand Up @@ -52,11 +53,19 @@ bool maybe_exec_cmd_M(client &client, network::request &msg)
return false;
}

void send_error_response(const network::base_broker &broker)
bool send_error_response(const network::base_broker &broker)
{
if (!broker.send(std::make_shared<network::error::response>())) {
SPDLOG_WARN("Failed to send error response");
try {
if (!broker.send(std::make_shared<network::error::response>())) {
SPDLOG_WARN("Failed to send error response");
return false;
}
} catch (const std::exception &e) {
SPDLOG_WARN("Failed to send error response: {}", e.what());
return false;
}

return true;
}

template <typename... Ms>
Expand All @@ -73,33 +82,53 @@ bool handle_message(client &client, const network::base_broker &broker,
}

bool send_error = false;
bool result = true;
try {
auto msg = broker.recv(initial_timeout);
return maybe_exec_cmd_M<Ms...>(client, msg);
} catch (const client_disconnect &) {
SPDLOG_INFO("Client has disconnected");
// When this exception has been received, we should stop hadling this
// particular client.
result = false;
} catch (const std::out_of_range &e) {
// The message received was too large, in theory this should've been
// flushed and we can continue handling messages from this client,
// however we need to report an error to ensure the client is in a good
// state.
SPDLOG_WARN("Failed to handle message: {}", e.what());
Anilm3 marked this conversation as resolved.
Show resolved Hide resolved
send_error = true;
} catch (const std::length_error &e) {
// The message was partially received, the state of the socket is
// undefined so we need to respond with an error and stop handling
// this client.
SPDLOG_WARN("Failed to handle message: {}", e.what());
send_error = true;
result = false;
} catch (const bad_cast &e) {
// The data received was somehow incomprehensible but we might still be
// able to continue, so we only send an error.
SPDLOG_WARN("Failed to handle message: {}", e.what());
send_error = true;
} catch (const msgpack::unpack_error &e) {
// The data received was somehow incomprehensible or perhaps beyond
// limits, but we might still be able to continue, so we only send an
// error.
SPDLOG_WARN("Failed to unpack message: {}", e.what());
send_error = true;
} catch (const std::exception &e) {
SPDLOG_WARN("Failed to handle message: {}", e.what());
result = false;
}

if (send_error) {
// This can happen due to a valid error, let's continue handling
// the client as this might just happen spuriously.
send_error_response(broker);
return true;
if (!send_error_response(broker)) {
return false;
}
}

// If we reach this point, there was a problem handling the message
return false;
return result;
}

} // namespace
Expand Down
16 changes: 12 additions & 4 deletions src/helper/network/broker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <msgpack.hpp>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>

namespace {

Expand Down Expand Up @@ -39,22 +40,29 @@ request broker::recv(std::chrono::milliseconds initial_timeout) const
"Not enough data for header:" + std::to_string(res) + " bytes");
}

// TODO: remove or increase this dramatically with WAF 1.5.0
static msgpack::unpack_limit const limits(max_array_size, max_map_size,
max_string_length, max_binary_size, max_extension_size, max_depth);

msgpack::unpacker u(&default_reference_func, MSGPACK_NULLPTR,
MSGPACK_UNPACKER_INIT_BUFFER_SIZE, limits); // NOLINT

static constexpr auto timeout_msg_body{std::chrono::milliseconds{300}};
socket_->set_recv_timeout(timeout_msg_body);

if (h.size >= max_msg_body_size) {
throw std::length_error(
auto res = socket_->discard(h.size);
if (res < h.size) {
throw std::length_error(
"Message body too large: " + std::to_string(h.size) +
" but failed to flush");
}

throw std::out_of_range(
"Message body too large: " + std::to_string(h.size));
}
// Allocate a buffer of the message size
u.reserve_buffer(h.size);

static constexpr auto timeout_msg_body{std::chrono::milliseconds{300}};
socket_->set_recv_timeout(timeout_msg_body);
res = socket_->recv(u.buffer(), h.size);
if (res != h.size) {
throw std::length_error(
Expand Down
18 changes: 18 additions & 0 deletions src/helper/network/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// This product includes software developed at Datadog
// (https://www.datadoghq.com/). Copyright 2021 Datadog, Inc.
#include "socket.hpp"
#include <array>
#include <cerrno>
#include <chrono>
#include <spdlog/spdlog.h>
Expand Down Expand Up @@ -45,6 +46,23 @@ std::size_t socket::send(const char *buffer, std::size_t len)
return res;
}

std::size_t socket::discard(std::size_t len)
{
constexpr auto max_size = std::numeric_limits<uint16_t>::max();
std::array<char, max_size> buffer{};

std::size_t total_size = 0;
while (total_size < len) {
auto read_size = std::min<std::size_t>(len - total_size, max_size);
ssize_t const res = ::recv(sock_, buffer.data(), read_size, 0);
if (res <= 0) {
break;
}
total_size += res;
}
return total_size;
}

namespace {
struct timeval from_chrono(std::chrono::milliseconds duration)
{
Expand Down
2 changes: 2 additions & 0 deletions src/helper/network/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class base_socket {

virtual std::size_t recv(char *buffer, std::size_t len) = 0;
virtual std::size_t send(const char *buffer, std::size_t len) = 0;
virtual std::size_t discard(std::size_t len) = 0;

virtual void set_send_timeout(std::chrono::milliseconds timeout) = 0;
virtual void set_recv_timeout(std::chrono::milliseconds timeout) = 0;
Expand Down Expand Up @@ -58,6 +59,7 @@ class socket : public base_socket {

std::size_t recv(char *buffer, std::size_t len) override;
std::size_t send(const char *buffer, std::size_t len) override;
std::size_t discard(std::size_t len) override;

void set_send_timeout(std::chrono::milliseconds timeout) override;
void set_recv_timeout(std::chrono::milliseconds timeout) override;
Expand Down
5 changes: 5 additions & 0 deletions tests/fuzzer/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class raw_socket : public network::base_socket {
return len;
}

std::size_t discard(std::size_t len) override
{
return r.read_bytes(nullptr, len);
}

void set_send_timeout(std::chrono::milliseconds timeout) override {}
void set_recv_timeout(std::chrono::milliseconds timeout) override {}
protected:
Expand Down
31 changes: 31 additions & 0 deletions tests/helper/broker_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <network/broker.hpp>
#include <network/socket.hpp>
#include <parameter_view.hpp>
#include <stdexcept>

namespace dds {

Expand All @@ -20,6 +21,7 @@ class socket : public network::base_socket {
~socket() override = default;
MOCK_METHOD2(recv, std::size_t(char *, std::size_t));
MOCK_METHOD2(send, std::size_t(const char *, std::size_t));
MOCK_METHOD1(discard, std::size_t(std::size_t));

void set_send_timeout(std::chrono::milliseconds timeout) override {}
void set_recv_timeout(std::chrono::milliseconds timeout) override {}
Expand Down Expand Up @@ -616,6 +618,35 @@ TEST(BrokerTest, ParsingBodyLimit)
network::header_t h{"dds", (uint32_t)expected_data.size()};
EXPECT_CALL(*socket, recv(_, _))
.WillOnce(DoAll(CopyHeader(&h), Return(sizeof(network::header_t))));
EXPECT_CALL(*socket, discard(h.size)).WillOnce(Return(h.size));

network::request request;
EXPECT_THROW(request = broker.recv(std::chrono::milliseconds(100)),
std::out_of_range);
}

TEST(BrokerTest, ParsingBodyLimitFailFlush)
{
mock::socket *socket = new mock::socket();
network::broker broker{std::unique_ptr<mock::socket>(socket)};

std::stringstream ss;
msgpack::packer<std::stringstream> packer(ss);
packer.pack_array(1);
pack_str(packer, "request_shutdown");
packer.pack_array(1);
packer.pack_map(16);
for (char c = 'a'; c < 'q'; c++) {
pack_str(packer, std::string(4, c));
pack_str(packer, std::string(4096, c));
}

const std::string &expected_data = ss.str();

network::header_t h{"dds", (uint32_t)expected_data.size()};
EXPECT_CALL(*socket, recv(_, _))
.WillOnce(DoAll(CopyHeader(&h), Return(sizeof(network::header_t))));
EXPECT_CALL(*socket, discard(_)).WillOnce(Return(h.size - 1));

network::request request;
EXPECT_THROW(request = broker.recv(std::chrono::milliseconds(100)),
Expand Down