Skip to content

Commit

Permalink
mpsc: Added lightweight notification
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnfeng0 committed Nov 17, 2024
1 parent 424ff3f commit 123dbf8
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 37 deletions.
46 changes: 46 additions & 0 deletions include/ulog/helper/queue/lite_notifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//
// Created by shawn on 24-11-17.
//

#pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>

namespace ulog {

/**
* @brief In conjunction with a lock-free queue, this notifier should be used when the queue is full or empty, as
* retries can affect performance.
*/
class LiteNotifier {
public:
template <typename Predicate>
bool wait_for(std::chrono::milliseconds timeout, Predicate p) {
if (p()) return true;

std::unique_lock<std::mutex> lk(mtx_);

auto ret = cv_.wait_for(lk, timeout, [&] {
signal_needed = true;
return p();
});

signal_needed = false;
return ret;
}

void notify_when_blocking() {
if (signal_needed.exchange(false)) {
std::unique_lock<std::mutex> lk(mtx_);
cv_.notify_all();
}
}

private:
std::mutex mtx_;
std::condition_variable cv_;
std::atomic_bool signal_needed{false};
};

} // namespace ulog
93 changes: 63 additions & 30 deletions include/ulog/helper/queue/mpsc_ring.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
#pragma once

#include <inttypes.h>
#include <unistd.h>

#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>

#include "lite_notifier.h"

// The basic principle of circular queue implementation:
// A - B is the position of A relative to B

Expand Down Expand Up @@ -90,27 +92,29 @@ class Umq {

~Umq() = default;

void Debug() {
const auto in = prod_tail_.load(std::memory_order_acquire);
const auto last = prod_last_.load(std::memory_order_relaxed);
const auto out = cons_head_.load(std::memory_order_relaxed);
void Debug(const bool all = false) {
if (all) {
const auto in = prod_tail_.load(std::memory_order_acquire);
const auto last = prod_last_.load(std::memory_order_relaxed);
const auto out = cons_head_.load(std::memory_order_relaxed);

const auto cur_in = in & mask();
const auto cur_last = last & mask();
const auto cur_out = out & mask();
const auto cur_in = in & mask();
const auto cur_last = last & mask();
const auto cur_out = out & mask();

printf("in: %zd(%zd), last: %zd(%zd), out: %zd(%zd)\n", in, cur_in, last, cur_last, out, cur_out);
for (size_t i = 0; i < buffer_size; i++) {
printf("|%02x", data_[i]);
printf("in: %zd(%zd), last: %zd(%zd), out: %zd(%zd)\n", in, cur_in, last, cur_last, out, cur_out);
for (size_t i = 0; i < buffer_size; i++) {
printf("|%02x", data_[i]);
}
printf("|\n");

for (size_t i = 0; i < buffer_size; i++) printf("%s", i == cur_in ? "^in" : " ");
printf("\n");
for (size_t i = 0; i < buffer_size; i++) printf("%s", i == cur_last ? "^last" : " ");
printf("\n");
for (size_t i = 0; i < buffer_size; i++) printf("%s", i == cur_out ? "^out" : " ");
printf("\n");
}
printf("|\n");

for (size_t i = 0; i < buffer_size; i++) printf("%s", i == cur_in ? "^in" : " ");
printf("\n");
for (size_t i = 0; i < buffer_size; i++) printf("%s", i == cur_last ? "^last" : " ");
printf("\n");
for (size_t i = 0; i < buffer_size; i++) printf("%s", i == cur_out ? "^out" : " ");
printf("\n");
}

private:
Expand All @@ -128,6 +132,9 @@ class Umq {

std::atomic<size_t> prod_tail_;
std::atomic<size_t> prod_last_;

LiteNotifier prod_notifier_;
LiteNotifier cons_notifier_;
};

template <int buffer_size>
Expand All @@ -138,12 +145,24 @@ class Producer {
~Producer() = default;

/**
* Try to reserve space of size size
* Reserve space of size, automatically retry until timeout
* @param size size of space to reserve
* @param timeout_ms The maximum waiting time if there is insufficient space in the queue
* @return data pointer if successful, otherwise nullptr
*/
void *ReserveOrWait(const size_t size, const uint32_t timeout_ms) {
void *ptr;
ring_->cons_notifier_.wait_for(std::chrono::milliseconds(timeout_ms),
[&] { return (ptr = Reserve(size)) != nullptr; });
return ptr;
}

/**
* Try to reserve space of size
* @param size size of space to reserve
* @param retry_times The number of retries when competing with other producers
* @return data pointer if successful, otherwise nullptr
*/
void *TryReserve(const size_t size, size_t retry_times = 128) {
void *Reserve(const size_t size) {
const auto packet_size = sizeof(Packet) + align8(size);

auto in = ring_->prod_tail_.load(std::memory_order_relaxed);
Expand Down Expand Up @@ -192,7 +211,7 @@ class Producer {
}
// Neither the end of the current range nor the head of the next range is enough
return nullptr;
} while (retry_times--);
} while (true);

pending_packet_->magic = 0xefbeefbe;
pending_packet_->set_size(size);
Expand All @@ -203,10 +222,11 @@ class Producer {
* Commits the data to the buffer, so that it can be read out.
*/
void Commit() {
if (pending_packet_) {
pending_packet_->mark_submitted();
pending_packet_ = nullptr;
}
if (!pending_packet_) return;

pending_packet_->mark_submitted();
pending_packet_ = nullptr;
ring_->prod_notifier_.notify_when_blocking();
}

private:
Expand All @@ -223,8 +243,21 @@ class Consumer {
void Debug() { ring_->Debug(); }

/**
* Gets a pointer to the contiguous block in the buffer, and returns the size
* of that block.
* Gets a pointer to the contiguous block in the buffer, and returns the size of that block. automatically retry until
* timeout
* @param out_size returns the size of the contiguous block
* @param timeout_ms The maximum waiting time
* @return pointer to the contiguous block
*/
void *ReadOrWait(uint32_t *out_size, const uint32_t timeout_ms) {
void *ptr;
ring_->prod_notifier_.wait_for(std::chrono::milliseconds(timeout_ms),
[&] { return (ptr = TryReadOnePacket(out_size)) != nullptr; });
return ptr;
}

/**
* Gets a pointer to the contiguous block in the buffer, and returns the size of that block.
* @param out_size returns the size of the contiguous block
* @return pointer to the contiguous block
*/
Expand Down Expand Up @@ -282,7 +315,6 @@ class Consumer {
*
* The validity of the size is not checked, it needs to be within the range
* returned by the TryReadOnePacket function.
*
*/
void ReleasePacket() { ReleasePacket(reading_packet_); }

Expand All @@ -299,6 +331,7 @@ class Consumer {
} else {
ring_->cons_head_.store(out + packet_size, std::memory_order_release);
}
ring_->cons_notifier_.notify_when_blocking();
}

void *CheckPacket(void *ptr, uint32_t *out_size) {
Expand Down
15 changes: 8 additions & 7 deletions tests/mpsc_ring_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
#include <random>
#include <thread>

#include "ulog/helper/queue/lite_notifier.h"
#include "ulog/ulog.h"

template <int buffer_size>
static void spsc(const size_t max_write_thread = 4) {
const uint64_t limit = buffer_size * 8192;
ulog::umq::Umq<buffer_size> buffer;
std::atomic_uint64_t write_count{0};

std::atomic_uint64_t write_count{0};
auto write_entry = [&] {
ulog::umq::Producer<buffer_size> producer(&buffer);
std::random_device rd;
Expand All @@ -27,7 +28,7 @@ static void spsc(const size_t max_write_thread = 4) {
if (write_count > limit) break;

size_t size = dis(gen);
const auto data = static_cast<uint8_t*>(producer.TryReserve(size));
const auto data = static_cast<uint8_t*>(producer.ReserveOrWait(size, 100));

if (data == nullptr) {
std::this_thread::yield();
Expand All @@ -48,7 +49,7 @@ static void spsc(const size_t max_write_thread = 4) {
uint64_t read_count = 0;
while (read_count < limit) {
uint32_t size;
const auto data = static_cast<uint8_t*>(consumer.TryReadOnePacket(&size));
const auto data = static_cast<uint8_t*>(consumer.ReadOrWait(&size, 100));
if (data == nullptr) {
std::this_thread::yield();
continue;
Expand All @@ -70,8 +71,8 @@ static void spsc(const size_t max_write_thread = 4) {
TEST(MpscRingTest, singl_producer_single_consumer) {
LOGGER_TIME_CODE({ spsc<1 << 5>(16); });
LOGGER_TIME_CODE({ spsc<1 << 6>(16); });
LOGGER_TIME_CODE({ spsc<1 << 7>(16); });
LOGGER_TIME_CODE({ spsc<1 << 8>(16); });
LOGGER_TIME_CODE({ spsc<1 << 9>(16); });
LOGGER_TIME_CODE({ spsc<1 << 10>(16); });
// LOGGER_TIME_CODE({ spsc<1 << 7>(16); });
// LOGGER_TIME_CODE({ spsc<1 << 8>(16); });
// LOGGER_TIME_CODE({ spsc<1 << 9>(16); });
// LOGGER_TIME_CODE({ spsc<1 << 10>(16); });
}

0 comments on commit 123dbf8

Please sign in to comment.