Skip to content

Commit

Permalink
Merge pull request #2 from kvcache-ai/rf_dev
Browse files Browse the repository at this point in the history
reorganize transfer engine to be more structured
  • Loading branch information
alogfans authored Jul 26, 2024
2 parents 12f4011 + 95e105e commit abda1b5
Show file tree
Hide file tree
Showing 13 changed files with 902 additions and 836 deletions.
4 changes: 1 addition & 3 deletions example/transfer_engine_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ DEFINE_string(operation, "read", "Operation type: read or write");
// "cpu:1": [["mlx5_3"], ["mlx5_2"]],
// "cuda:0": [["mlx5_2"], ["mlx5_3"]],
// }
DEFINE_string(nic_priority_matrix, "{\"cpu:0\": [[\"mlx5_2\"], [\"mlx5_3\"]], \"cpu:1\": [[\"mlx5_3\"], [\"mlx5_2\"]]}", "NIC priority matrix");
DEFINE_string(nic_priority_matrix, "{\"cpu:0\": [[\"mlx5_2\", \"mlx5_3\"], []], \"cpu:1\": [[\"mlx5_3\"], [\"mlx5_2\"]]}", "NIC priority matrix");
DEFINE_string(segment_id, "optane20", "Segment ID to access data");
DEFINE_int32(batch_size, 128, "Batch size");
DEFINE_int32(block_size, 4096, "Block size for each transfer request");
Expand Down Expand Up @@ -129,7 +129,6 @@ int initiator()
getHostname(),
FLAGS_nic_priority_matrix);
LOG_ASSERT(engine);
engine->updateRnicLinkSpeed({200, 100});

void *addr = allocateMemoryPool(dram_buffer_size);
engine->registerLocalMemory(addr, dram_buffer_size, "cpu:0");
Expand Down Expand Up @@ -179,7 +178,6 @@ int target()
getHostname(),
FLAGS_nic_priority_matrix);
LOG_ASSERT(engine);
engine->updateRnicLinkSpeed({200, 100});

void *addr = allocateMemoryPool(dram_buffer_size);
engine->registerLocalMemory(addr, dram_buffer_size, "cpu:0");
Expand Down
3 changes: 2 additions & 1 deletion src/transfer_engine/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_library(transfer_engine transfer_engine.cpp
transfer_metadata.cpp
rdma_context.cpp
rdma_endpoint.cpp)
rdma_endpoint.cpp
worker_pool.cpp)
target_link_libraries(transfer_engine PUBLIC ibverbs glog gflags pthread memcached jsoncpp)
134 changes: 96 additions & 38 deletions src/transfer_engine/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cstdint>
#include <ctime>
#include <atomic>
#include <thread>
#include <sys/mman.h>
#include <numa.h>

Expand Down Expand Up @@ -131,65 +132,123 @@ namespace mooncake

class RWSpinlock
{
union RWTicket
{
constexpr RWTicket() : whole(0) {}
uint64_t whole;
uint32_t readWrite;
struct
{
uint16_t write;
uint16_t read;
uint16_t users;
};
} ticket;

private:
static void asm_volatile_memory()
{
asm volatile("" ::: "memory");
}

template <class T>
static T load_acquire(T *addr)
{
T t = *addr;
asm_volatile_memory();
return t;
}

template <class T>
static void store_release(T *addr, T v)
{
asm_volatile_memory();
*addr = v;
}

public:
RWSpinlock() : lock_(0) {}
RWSpinlock() {}

~RWSpinlock() {}
RWSpinlock(RWSpinlock const &) = delete;
RWSpinlock &operator=(RWSpinlock const &) = delete;

RWSpinlock(const RWSpinlock &) = delete;
void lock()
{
writeLockNice();
}

RWSpinlock &operator=(const RWSpinlock &) = delete;
bool tryLock()
{
RWTicket t;
uint64_t old = t.whole = load_acquire(&ticket.whole);
if (t.users != t.write)
return false;
++t.users;
return __sync_bool_compare_and_swap(&ticket.whole, old, t.whole);
}

void RLock()
void writeLockAggressive()
{
while (true)
uint32_t count = 0;
uint16_t val = __sync_fetch_and_add(&ticket.users, 1);
while (val != load_acquire(&ticket.write))
{
int64_t lock = lock_.fetch_add(1, std::memory_order_relaxed);
if (lock >= 0)
break;
lock_.fetch_sub(1, std::memory_order_relaxed);
PAUSE();
if (++count > 1000)
std::this_thread::yield();
}
std::atomic_thread_fence(std::memory_order_acquire);
}

void RUnlock()
void writeLockNice()
{
std::atomic_thread_fence(std::memory_order_release);
int64_t lock = lock_.fetch_sub(1, std::memory_order_relaxed);
LOG_ASSERT(lock > 0);
while (!tryLock())
;
}

void WLock()
void unlockAndLockShared()
{
while (true)
{
int64_t lock;
while ((lock = lock_.load(std::memory_order_relaxed)))
PAUSE();
if (lock_.compare_exchange_weak(lock, kExclusiveLock, std::memory_order_relaxed))
break;
}
std::atomic_thread_fence(std::memory_order_acquire);
uint16_t val = __sync_fetch_and_add(&ticket.read, 1);
(void)val;
}

void unlock()
{
RWTicket t;
t.whole = load_acquire(&ticket.whole);
++t.read;
++t.write;
store_release(&ticket.readWrite, t.readWrite);
}

void WUnlock()
void lockShared()
{
while (true)
uint_fast32_t count = 0;
while (!tryLockShared())
{
int64_t lock;
while ((lock = lock_.load(std::memory_order_relaxed)) != kExclusiveLock)
PAUSE();
std::atomic_thread_fence(std::memory_order_release);
if (lock_.compare_exchange_weak(lock, 0, std::memory_order_relaxed))
return;
_mm_pause();
if ((++count & 1023) == 0)
std::this_thread::yield();
}
}

bool tryLockShared()
{
RWTicket t, old;
old.whole = t.whole = load_acquire(&ticket.whole);
old.users = old.read;
++t.read;
++t.users;
return __sync_bool_compare_and_swap(&ticket.whole, old.whole, t.whole);
}

void unlockShared() { __sync_fetch_and_add(&ticket.write, 1); }

public:
struct WriteGuard
{
WriteGuard(RWSpinlock &lock) : lock(lock)
{
lock.WLock();
lock.lock();
}

WriteGuard(const WriteGuard &) = delete;
Expand All @@ -198,7 +257,7 @@ namespace mooncake

~WriteGuard()
{
lock.WUnlock();
lock.unlock();
}

RWSpinlock &lock;
Expand All @@ -208,7 +267,7 @@ namespace mooncake
{
ReadGuard(RWSpinlock &lock) : lock(lock)
{
lock.RLock();
lock.lockShared();
}

ReadGuard(const ReadGuard &) = delete;
Expand All @@ -217,7 +276,7 @@ namespace mooncake

~ReadGuard()
{
lock.RUnlock();
lock.unlockShared();
}

RWSpinlock &lock;
Expand All @@ -229,7 +288,6 @@ namespace mooncake
std::atomic<int64_t> lock_;
uint64_t padding_[15];
};

}

#endif // COMMON_H
Loading

0 comments on commit abda1b5

Please sign in to comment.