Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor streams: rename is_* to wait_* for clarity #2069

Merged
merged 1 commit into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,8 @@ class Stream {
virtual ~Stream() = default;

virtual bool is_readable() const = 0;
virtual bool is_writable() const = 0;
virtual bool wait_readable() const = 0;
virtual bool wait_writable() const = 0;

virtual ssize_t read(char *ptr, size_t size) = 0;
virtual ssize_t write(const char *ptr, size_t size) = 0;
Expand Down Expand Up @@ -2466,7 +2467,8 @@ class BufferStream final : public Stream {
~BufferStream() override = default;

bool is_readable() const override;
bool is_writable() const override;
bool wait_readable() const override;
bool wait_writable() const override;
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
Expand Down Expand Up @@ -3380,7 +3382,8 @@ class SocketStream final : public Stream {
~SocketStream() override;

bool is_readable() const override;
bool is_writable() const override;
bool wait_readable() const override;
bool wait_writable() const override;
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
Expand Down Expand Up @@ -3416,7 +3419,8 @@ class SSLSocketStream final : public Stream {
~SSLSocketStream() override;

bool is_readable() const override;
bool is_writable() const override;
bool wait_readable() const override;
bool wait_writable() const override;
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
Expand Down Expand Up @@ -4578,7 +4582,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,

data_sink.write = [&](const char *d, size_t l) -> bool {
if (ok) {
if (strm.is_writable() && write_data(strm, d, l)) {
if (write_data(strm, d, l)) {
offset += l;
} else {
ok = false;
Expand All @@ -4587,10 +4591,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
return ok;
};

data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };

while (offset < end_offset && !is_shutting_down()) {
if (!strm.is_writable()) {
if (!strm.wait_writable()) {
error = Error::Write;
return false;
} else if (!content_provider(offset, end_offset - offset, data_sink)) {
Expand Down Expand Up @@ -4628,17 +4632,17 @@ write_content_without_length(Stream &strm,
data_sink.write = [&](const char *d, size_t l) -> bool {
if (ok) {
offset += l;
if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; }
if (!write_data(strm, d, l)) { ok = false; }
}
return ok;
};

data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };

data_sink.done = [&](void) { data_available = false; };

while (data_available && !is_shutting_down()) {
if (!strm.is_writable()) {
if (!strm.wait_writable()) {
return false;
} else if (!content_provider(offset, 0, data_sink)) {
return false;
Expand Down Expand Up @@ -4673,10 +4677,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
// Emit chunked response header and footer for each chunk
auto chunk =
from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
if (!strm.is_writable() ||
!write_data(strm, chunk.data(), chunk.size())) {
ok = false;
}
if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; }
}
} else {
ok = false;
Expand All @@ -4685,7 +4686,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
return ok;
};

data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };

auto done_with_trailer = [&](const Headers *trailer) {
if (!ok) { return; }
Expand All @@ -4705,8 +4706,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
if (!payload.empty()) {
// Emit chunked response header and footer for each chunk
auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
if (!strm.is_writable() ||
!write_data(strm, chunk.data(), chunk.size())) {
if (!write_data(strm, chunk.data(), chunk.size())) {
ok = false;
return;
}
Expand Down Expand Up @@ -4738,7 +4738,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
};

while (data_available && !is_shutting_down()) {
if (!strm.is_writable()) {
if (!strm.wait_writable()) {
error = Error::Write;
return false;
} else if (!content_provider(offset, 0, data_sink)) {
Expand Down Expand Up @@ -6029,6 +6029,10 @@ inline SocketStream::SocketStream(
inline SocketStream::~SocketStream() = default;

inline bool SocketStream::is_readable() const {
return read_buff_off_ < read_buff_content_size_;
}

inline bool SocketStream::wait_readable() const {
if (max_timeout_msec_ <= 0) {
return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
}
Expand All @@ -6041,7 +6045,7 @@ inline bool SocketStream::is_readable() const {
return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
}

inline bool SocketStream::is_writable() const {
inline bool SocketStream::wait_writable() const {
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
is_socket_alive(sock_);
}
Expand All @@ -6068,7 +6072,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
}
}

if (!is_readable()) { return -1; }
if (!wait_readable()) { return -1; }

read_buff_off_ = 0;
read_buff_content_size_ = 0;
Expand All @@ -6093,7 +6097,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
}

inline ssize_t SocketStream::write(const char *ptr, size_t size) {
if (!is_writable()) { return -1; }
if (!wait_writable()) { return -1; }

#if defined(_WIN32) && !defined(_WIN64)
size =
Expand Down Expand Up @@ -6124,7 +6128,9 @@ inline time_t SocketStream::duration() const {
// Buffer stream implementation
inline bool BufferStream::is_readable() const { return true; }

inline bool BufferStream::is_writable() const { return true; }
inline bool BufferStream::wait_readable() const { return true; }

inline bool BufferStream::wait_writable() const { return true; }

inline ssize_t BufferStream::read(char *ptr, size_t size) {
#if defined(_MSC_VER) && _MSC_VER < 1910
Expand Down Expand Up @@ -9161,6 +9167,10 @@ inline SSLSocketStream::SSLSocketStream(
inline SSLSocketStream::~SSLSocketStream() = default;

inline bool SSLSocketStream::is_readable() const {
return SSL_pending(ssl_) > 0;
}

inline bool SSLSocketStream::wait_readable() const {
if (max_timeout_msec_ <= 0) {
return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
}
Expand All @@ -9173,15 +9183,15 @@ inline bool SSLSocketStream::is_readable() const {
return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
}

inline bool SSLSocketStream::is_writable() const {
inline bool SSLSocketStream::wait_writable() const {
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_);
}

inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
if (SSL_pending(ssl_) > 0) {
return SSL_read(ssl_, ptr, static_cast<int>(size));
} else if (is_readable()) {
} else if (wait_readable()) {
auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
if (ret < 0) {
auto err = SSL_get_error(ssl_, ret);
Expand All @@ -9195,7 +9205,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
#endif
if (SSL_pending(ssl_) > 0) {
return SSL_read(ssl_, ptr, static_cast<int>(size));
} else if (is_readable()) {
} else if (wait_readable()) {
std::this_thread::sleep_for(std::chrono::microseconds{10});
ret = SSL_read(ssl_, ptr, static_cast<int>(size));
if (ret >= 0) { return ret; }
Expand All @@ -9212,7 +9222,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
}

inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
if (is_writable()) {
if (wait_writable()) {
auto handle_size = static_cast<int>(
std::min<size_t>(size, (std::numeric_limits<int>::max)()));

Expand All @@ -9227,7 +9237,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
#else
while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
#endif
if (is_writable()) {
if (wait_writable()) {
std::this_thread::sleep_for(std::chrono::microseconds{10});
ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
if (ret >= 0) { return ret; }
Expand Down
4 changes: 3 additions & 1 deletion test/fuzzing/server_fuzzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class FuzzedStream : public httplib::Stream {

bool is_readable() const override { return true; }

bool is_writable() const override { return true; }
bool wait_readable() const override { return true; }

bool wait_writable() const override { return true; }

void get_remote_ip_and_port(std::string &ip, int &port) const override {
ip = "127.0.0.1";
Expand Down
10 changes: 5 additions & 5 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ TEST_F(UnixSocketTest, abstract) {
}
#endif

TEST(SocketStream, is_writable_UNIX) {
TEST(SocketStream, wait_writable_UNIX) {
int fds[2];
ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));

Expand All @@ -167,17 +167,17 @@ TEST(SocketStream, is_writable_UNIX) {
};
asSocketStream(fds[0], [&](Stream &s0) {
EXPECT_EQ(s0.socket(), fds[0]);
EXPECT_TRUE(s0.is_writable());
EXPECT_TRUE(s0.wait_writable());

EXPECT_EQ(0, close(fds[1]));
EXPECT_FALSE(s0.is_writable());
EXPECT_FALSE(s0.wait_writable());

return true;
});
EXPECT_EQ(0, close(fds[0]));
}

TEST(SocketStream, is_writable_INET) {
TEST(SocketStream, wait_writable_INET) {
sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
Expand Down Expand Up @@ -212,7 +212,7 @@ TEST(SocketStream, is_writable_INET) {
};
asSocketStream(disconnected_svr_sock, [&](Stream &ss) {
EXPECT_EQ(ss.socket(), disconnected_svr_sock);
EXPECT_FALSE(ss.is_writable());
EXPECT_FALSE(ss.wait_writable());

return true;
});
Expand Down
Loading