Skip to content

Commit

Permalink
feat(qp): implement trait for QP
Browse files Browse the repository at this point in the history
Just like what we did with CQ

Refs:
- #14
- #15

Signed-off-by: Luke Yue <lukedyue@gmail.com>
  • Loading branch information
dragonJACson committed Sep 16, 2024
1 parent 815a5df commit 79b3b5e
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 35 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ rdma-mummy-sys = "0.1"
tabled = "0.16"
libc = "0.2"
os_socketaddr = "0.2"
bitmask-enum = "2.2"
16 changes: 11 additions & 5 deletions examples/test_qp.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use rdma_mummy_sys::ibv_access_flags;
use sideway::verbs::{
address::AddressHandleAttribute,
device,
device_context::Mtu,
queue_pair::{QueuePairAttribute, QueuePairState},
queue_pair::{QueuePair, QueuePairAttribute, QueuePairState},
AccessFlags,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -12,9 +12,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let ctx = device.open().unwrap();

let pd = ctx.alloc_pd().unwrap();
let mr = pd.reg_managed_mr(64).unwrap();
let _mr = pd.reg_managed_mr(64).unwrap();

let comp_channel = ctx.create_comp_channel().unwrap();
let _comp_channel = ctx.create_comp_channel().unwrap();
let mut cq_builder = ctx.create_cq_builder();
let sq = cq_builder.setup_cqe(128).build().unwrap();
let rq = cq_builder.setup_cqe(128).build().unwrap();
Expand All @@ -34,9 +34,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
attr.setup_state(QueuePairState::Init)
.setup_pkey_index(0)
.setup_port(1)
.setup_access_flags(ibv_access_flags::IBV_ACCESS_REMOTE_WRITE);
.setup_access_flags(AccessFlags::LocalWrite | AccessFlags::RemoteWrite);
qp.modify(&attr).unwrap();

assert_eq!(QueuePairState::Init, qp.state());

// modify QP to RTR state
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToReceive)
Expand All @@ -59,6 +61,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
attr.setup_address_vector(&ah_attr);
qp.modify(&attr).unwrap();

assert_eq!(QueuePairState::ReadyToReceive, qp.state());

// modify QP to RTS state
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToSend)
Expand All @@ -69,6 +73,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.setup_max_read_atomic(0);

qp.modify(&attr).unwrap();

assert_eq!(QueuePairState::ReadyToSend, qp.state());
}

Ok(())
Expand Down
19 changes: 19 additions & 0 deletions src/verbs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,22 @@ pub mod device_context;
pub mod memory_region;
pub mod protection_domain;
pub mod queue_pair;

use bitmask_enum::bitmask;
use rdma_mummy_sys::ibv_access_flags;

#[bitmask(i32)]
#[bitmask_config(vec_debug)]
pub enum AccessFlags {
LocalWrite = ibv_access_flags::IBV_ACCESS_LOCAL_WRITE.0 as _,
RemoteWrite = ibv_access_flags::IBV_ACCESS_REMOTE_WRITE.0 as _,
RemoteRead = ibv_access_flags::IBV_ACCESS_REMOTE_READ.0 as _,
RemoteAtomic = ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC.0 as _,
MemoryWindowsBind = ibv_access_flags::IBV_ACCESS_MW_BIND.0 as _,
ZeroBased = ibv_access_flags::IBV_ACCESS_ZERO_BASED.0 as _,
OnDemand = ibv_access_flags::IBV_ACCESS_ON_DEMAND.0 as _,
HugeTlb = ibv_access_flags::IBV_ACCESS_HUGETLB.0 as _,
FlushGlobal = ibv_access_flags::IBV_ACCESS_FLUSH_GLOBAL.0 as _,
FlushPersistent = ibv_access_flags::IBV_ACCESS_FLUSH_PERSISTENT.0 as _,
RelaxedOrdering = ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING.0 as _,
}
10 changes: 9 additions & 1 deletion src/verbs/protection_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::{
device_context::DeviceContext,
memory_region::{Buffer, MemoryRegion},
queue_pair::QueuePairBuilder,
AccessFlags,
};

#[derive(Debug)]
Expand Down Expand Up @@ -35,7 +36,14 @@ impl ProtectionDomain<'_> {
pub fn reg_managed_mr(&self, size: usize) -> Result<MemoryRegion, String> {
let buf = Buffer::from_len_zeroed(size);

let mr = unsafe { ibv_reg_mr(self.pd.as_ptr(), buf.data.as_ptr() as _, buf.len, 0) };
let mr = unsafe {
ibv_reg_mr(
self.pd.as_ptr(),
buf.data.as_ptr() as _,
buf.len,
(AccessFlags::RemoteWrite | AccessFlags::LocalWrite).into(),
)
};

if mr.is_null() {
return Err(format!("{:?}", io::Error::last_os_error()));
Expand Down
153 changes: 124 additions & 29 deletions src/verbs/queue_pair.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use bitmask_enum::bitmask;
use rdma_mummy_sys::{
ibv_access_flags, ibv_create_qp, ibv_destroy_qp, ibv_modify_qp, ibv_qp, ibv_qp_attr, ibv_qp_attr_mask, ibv_qp_cap,
ibv_qp_ex, ibv_qp_init_attr, ibv_qp_init_attr_ex, ibv_qp_state, ibv_qp_type, ibv_rx_hash_conf,
ibv_create_qp, ibv_create_qp_ex, ibv_destroy_qp, ibv_modify_qp, ibv_qp, ibv_qp_attr, ibv_qp_attr_mask, ibv_qp_cap,
ibv_qp_create_send_ops_flags, ibv_qp_ex, ibv_qp_init_attr, ibv_qp_init_attr_ex, ibv_qp_init_attr_mask,
ibv_qp_state, ibv_qp_to_qp_ex, ibv_qp_type, ibv_rx_hash_conf,
};
use std::{
io,
Expand All @@ -11,7 +13,7 @@ use std::{

use super::{
address::AddressHandleAttribute, completion::CompletionQueue, device_context::Mtu,
protection_domain::ProtectionDomain,
protection_domain::ProtectionDomain, AccessFlags,
};

#[repr(u32)]
Expand All @@ -26,7 +28,7 @@ pub enum QueuePairType {
}

#[repr(u32)]
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueuePairState {
Reset = ibv_qp_state::IBV_QPS_RESET,
Init = ibv_qp_state::IBV_QPS_INIT,
Expand All @@ -38,26 +40,105 @@ pub enum QueuePairState {
Unknown = ibv_qp_state::IBV_QPS_UNKNOWN,
}

impl From<u32> for QueuePairState {
fn from(state: u32) -> Self {
match state {
ibv_qp_state::IBV_QPS_RESET => QueuePairState::Reset,
ibv_qp_state::IBV_QPS_INIT => QueuePairState::Init,
ibv_qp_state::IBV_QPS_RTR => QueuePairState::ReadyToReceive,
ibv_qp_state::IBV_QPS_RTS => QueuePairState::ReadyToSend,
ibv_qp_state::IBV_QPS_SQD => QueuePairState::SendQueueDrain,
ibv_qp_state::IBV_QPS_SQE => QueuePairState::SendQueueError,
ibv_qp_state::IBV_QPS_ERR => QueuePairState::Error,
ibv_qp_state::IBV_QPS_UNKNOWN => QueuePairState::Unknown,
_ => panic!("Unknown qp state: {state}"),
}
}
}

#[bitmask(u64)]
#[bitmask_config(vec_debug)]
pub enum SendOperationType {
Write = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_WRITE.0 as _,
WriteWithImmediate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_WRITE_WITH_IMM.0 as _,
Send = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND.0 as _,
SendWithImmediate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND_WITH_IMM.0 as _,
Read = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_READ.0 as _,
AtomicCompareAndSwap = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_CMP_AND_SWP.0 as _,
AtomicFetchAndAdd = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_FETCH_AND_ADD.0 as _,
LocalInvalidate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_LOCAL_INV.0 as _,
BindMemoryWindows = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_BIND_MW.0 as _,
SendWithInvalidate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND_WITH_INV.0 as _,
TcpSegmentationOffload = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_TSO.0 as _,
Flush = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_FLUSH.0 as _,
AtomicWrite = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_WRITE.0 as _,
}

pub trait QueuePair {
//! return the basic handle of QP;
//! we mark this method unsafe because the lifetime of ibv_qp is not
//! associated with the return value.
unsafe fn qp(&self) -> NonNull<ibv_qp>;

fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), String> {
// ibv_qp_attr does not impl Clone trait, so we use struct update syntax here
let mut qp_attr = ibv_qp_attr { ..attr.attr };
let ret = unsafe { ibv_modify_qp(self.qp().as_ptr(), &mut qp_attr as *mut _, attr.attr_mask.0 as _) };
if ret == 0 {
Ok(())
} else {
Err(format!("ibv_modify_qp failed, err={ret}"))
}
}

fn state(&self) -> QueuePairState {
unsafe { (*self.qp().as_ref()).state.into() }
}
}

#[derive(Debug)]
pub struct QueuePair<'res> {
pub struct BasicQueuePair<'res> {
pub(crate) qp: NonNull<ibv_qp>,
// phantom data for protection domain & completion queues
_phantom: PhantomData<&'res ()>,
}

impl Drop for QueuePair<'_> {
impl Drop for BasicQueuePair<'_> {
fn drop(&mut self) {
let ret = unsafe { ibv_destroy_qp(self.qp.as_ptr()) };
assert_eq!(ret, 0);
}
}

pub struct QueuePairExtended<'res> {
impl QueuePair for BasicQueuePair<'_> {
unsafe fn qp(&self) -> NonNull<ibv_qp> {
self.qp
}
}

#[derive(Debug)]
pub struct ExtendedQueuePair<'res> {
pub(crate) qp_ex: NonNull<ibv_qp_ex>,
// phantom data for protection domain & completion queues
_phantom: PhantomData<&'res ()>,
}

impl Drop for ExtendedQueuePair<'_> {
fn drop(&mut self) {
let ret = unsafe { ibv_destroy_qp(self.qp_ex.as_ptr().cast()) };
assert_eq!(ret, 0)
}
}

impl QueuePair for ExtendedQueuePair<'_> {
unsafe fn qp(&self) -> NonNull<ibv_qp> {
self.qp_ex.cast()
}
}

pub struct QueuePairBuilder<'res> {
init_attr: ibv_qp_init_attr_ex,
// phantom data for protection domain & completion queues
_phantom: PhantomData<&'res ()>,
}

Expand Down Expand Up @@ -141,7 +222,14 @@ impl<'res> QueuePairBuilder<'res> {
self
}

pub fn build(&self) -> Result<QueuePair<'res>, String> {
pub fn setup_send_ops_flags(&mut self, send_ops_flags: SendOperationType) -> &mut Self {
self.init_attr.send_ops_flags = send_ops_flags.bits;
self.init_attr.comp_mask |= ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_SEND_OPS_FLAGS.0;
self
}

// build basic qp
pub fn build(&self) -> Result<BasicQueuePair<'res>, String> {
let qp = unsafe {
ibv_create_qp(
self.init_attr.pd,
Expand All @@ -157,31 +245,39 @@ impl<'res> QueuePairBuilder<'res> {
)
};

Ok(QueuePair {
Ok(BasicQueuePair {
qp: NonNull::new(qp).ok_or(format!("ibv_create_qp failed, {}", io::Error::last_os_error()))?,
_phantom: PhantomData,
})
}

pub fn build_ex() -> () {
todo!();
}
}
// build extended qp
pub fn build_ex(&self) -> Result<ExtendedQueuePair<'res>, String> {
let mut attr = self.init_attr.clone();

// to build a real extended qp instead of a basic qp, we need to pass in
// these essential attributes.
attr.comp_mask |=
ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_PD.0 | ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_SEND_OPS_FLAGS.0;

// unless user specified, we assume every extended qp would support send,
// write and read, just as what basic qp supports.
if attr.send_ops_flags == 0 {
attr.send_ops_flags = (SendOperationType::Send
| SendOperationType::SendWithImmediate
| SendOperationType::Write
| SendOperationType::WriteWithImmediate
| SendOperationType::Read)
.into()
}

impl QueuePair<'_> {
pub(crate) fn new<'pd>(pd: &'pd ProtectionDomain) -> Self {
todo!()
}
let qp = unsafe { ibv_create_qp_ex((*(attr.pd)).context, &mut attr).unwrap_or(null_mut()) };

pub fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), String> {
// ibv_qp_attr does not impl Clone trait, so we use struct update syntax here
let mut qp_attr = ibv_qp_attr { ..attr.attr };
let ret = unsafe { ibv_modify_qp(self.qp.as_ptr(), &mut qp_attr as *mut _, attr.attr_mask.0 as _) };
if ret == 0 {
Ok(())
} else {
Err(format!("ibv_modify_qp failed, err={ret}"))
}
Ok(ExtendedQueuePair {
qp_ex: NonNull::new(unsafe { ibv_qp_to_qp_ex(qp) })
.ok_or(format!("ibv_create_qp_ex failed, {}", io::Error::last_os_error()))?,
_phantom: PhantomData,
})
}
}

Expand Down Expand Up @@ -226,9 +322,8 @@ impl QueuePairAttribute {
self
}

// TODO(fuji): use ibv_access_flags directly or wrap a type for this?
pub fn setup_access_flags(&mut self, access_flags: ibv_access_flags) -> &mut Self {
self.attr.qp_access_flags = access_flags.0;
pub fn setup_access_flags(&mut self, access_flags: AccessFlags) -> &mut Self {
self.attr.qp_access_flags = access_flags.bits as _;
self.attr_mask |= ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
self
}
Expand Down

0 comments on commit 79b3b5e

Please sign in to comment.