Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce PostSendGuard for extended QP with basic support for polling extended CQ #22

Merged
merged 3 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ libc = "0.2"
os_socketaddr = "0.2"
bitmask-enum = "2.2"
lazy_static = "1.5.0"

[dev-dependencies]
trybuild = "1.0"
113 changes: 113 additions & 0 deletions examples/test_post_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use core::time;
use std::thread;

use sideway::verbs::{
address::AddressHandleAttribute,
device,
device_context::Mtu,
queue_pair::{PostSendGuard, QueuePair, QueuePairAttribute, QueuePairState, SetInlineData, WorkRequestFlags},
AccessFlags,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
let device_list = device::DeviceList::new()?;
for device in &device_list {
let ctx = device.open().unwrap();

let pd = ctx.alloc_pd().unwrap();
let mr = pd.reg_managed_mr(64).unwrap();

let _comp_channel = ctx.create_comp_channel().unwrap();
let mut cq_builder = ctx.create_cq_builder();
let sq = cq_builder.setup_cqe(128).build_ex().unwrap();
let rq = cq_builder.setup_cqe(128).build_ex().unwrap();

let mut builder = pd.create_qp_builder();

let mut qp = builder
.setup_max_inline_data(128)
.setup_send_cq(&sq)
.setup_recv_cq(&rq)
.build_ex()
.unwrap();

println!("qp pointer is {:?}", qp);
// modify QP to INIT state
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::Init)
.setup_pkey_index(0)
.setup_port(1)
.setup_access_flags(AccessFlags::LocalWrite | AccessFlags::RemoteWrite);
qp.modify(&attr).unwrap();

assert_eq!(QueuePairState::Init, qp.state());

// modify QP to RTR state, set dest qp as itself
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToReceive)
.setup_path_mtu(Mtu::Mtu1024)
.setup_dest_qp_num(qp.qp_number())
.setup_rq_psn(1)
.setup_max_dest_read_atomic(0)
.setup_min_rnr_timer(0);
// setup address vector
let mut ah_attr = AddressHandleAttribute::new();
let gid_entries = ctx.query_gid_table().unwrap();

ah_attr
.setup_dest_lid(1)
.setup_port(1)
.setup_service_level(1)
.setup_grh_src_gid_index(gid_entries[0].gid_index().try_into().unwrap())
.setup_grh_dest_gid(&gid_entries[0].gid())
.setup_grh_hop_limit(64);
attr.setup_address_vector(&ah_attr);
qp.modify(&attr).unwrap();

assert_eq!(QueuePairState::ReadyToReceive, qp.state());

// modify QP to RTS state
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToSend)
.setup_sq_psn(1)
.setup_timeout(12)
.setup_retry_cnt(7)
.setup_rnr_retry(7)
.setup_max_read_atomic(0);

qp.modify(&attr).unwrap();

assert_eq!(QueuePairState::ReadyToSend, qp.state());

let mut guard = qp.start_post_send();
let buf = vec![0, 1, 2, 3];

let write_handle = guard
.construct_wr(233, WorkRequestFlags::Signaled)
.setup_write(mr.rkey(), mr.buf.data.as_ptr() as _);

write_handle.setup_inline_data(&buf);

let _err = guard.post().unwrap();

thread::sleep(time::Duration::from_millis(10));

// poll for the completion
{
let mut poller = sq.start_poll().unwrap();
let mut wc = poller.iter_mut();
println!("wr_id {}, status: {}, opcode: {}", wc.wr_id(), wc.status(), wc.opcode());
assert_eq!(wc.wr_id(), 233);
while let Some(wc) = wc.next() {
println!("wr_id {}, status: {}, opcode: {}", wc.wr_id(), wc.status(), wc.opcode())
}
}

unsafe {
let slice = std::slice::from_raw_parts(mr.buf.data.as_ptr(), mr.buf.len);
println!("Buffer contents: {:?}", slice);
}
}

Ok(())
}
93 changes: 91 additions & 2 deletions src/verbs/completion.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::marker::PhantomData;
use std::os::raw::c_void;
use std::ptr;
use std::ptr::NonNull;
use std::{marker::PhantomData, mem::MaybeUninit};

use super::device_context::DeviceContext;
use rdma_mummy_sys::{
ibv_comp_channel, ibv_cq, ibv_cq_ex, ibv_cq_init_attr_ex, ibv_create_comp_channel, ibv_create_cq, ibv_create_cq_ex,
ibv_destroy_comp_channel, ibv_destroy_cq, ibv_pd,
ibv_destroy_comp_channel, ibv_destroy_cq, ibv_end_poll, ibv_next_poll, ibv_pd, ibv_poll_cq_attr, ibv_start_poll,
ibv_wc_read_byte_len, ibv_wc_read_completion_ts, ibv_wc_read_opcode, ibv_wc_read_vendor_err,
};

#[derive(Debug)]
Expand Down Expand Up @@ -85,6 +86,25 @@ impl CompletionQueue for ExtendedCompletionQueue<'_> {
}
}

impl ExtendedCompletionQueue<'_> {
pub fn start_poll<'cq>(&'cq self) -> Result<ExtendedPoller<'cq>, String> {
let ret = unsafe {
ibv_start_poll(
self.cq_ex.as_ptr(),
MaybeUninit::<ibv_poll_cq_attr>::zeroed().as_mut_ptr(),
)
};

match ret {
0 => Ok(ExtendedPoller {
cq: self.cq_ex,
_phantom: PhantomData,
}),
err => Err(format!("ibv_start_poll failed, ret={err}")),
}
}
}

// generic builder for both cq and cq_ex
pub struct CompletionQueueBuilder<'res> {
dev_ctx: &'res DeviceContext,
Expand Down Expand Up @@ -131,6 +151,7 @@ impl<'res> CompletionQueueBuilder<'res> {
self.init_attr.comp_vector = comp_vector;
self
}

// TODO(fuji): set various attributes

// build extended cq
Expand Down Expand Up @@ -167,3 +188,71 @@ impl<'res> CompletionQueueBuilder<'res> {
}

// TODO trait for both cq and cq_ex?

pub struct ExtendedWorkCompletion<'cq> {
cq: NonNull<ibv_cq_ex>,
_phantom: PhantomData<&'cq ()>,
}

impl<'cq> ExtendedWorkCompletion<'cq> {
pub fn wr_id(&self) -> u64 {
unsafe { self.cq.as_ref().wr_id }
}

pub fn status(&self) -> u32 {
unsafe { self.cq.as_ref().status }
}

pub fn opcode(&self) -> u32 {
unsafe { ibv_wc_read_opcode(self.cq.as_ptr()) }
}

pub fn vendor_err(&self) -> u32 {
unsafe { ibv_wc_read_vendor_err(self.cq.as_ptr()) }
}

pub fn byte_len(&self) -> u32 {
unsafe { ibv_wc_read_byte_len(self.cq.as_ptr()) }
}

pub fn completion_timestamp(&self) -> u64 {
unsafe { ibv_wc_read_completion_ts(self.cq.as_ptr()) }
}
}

pub struct ExtendedPoller<'cq> {
cq: NonNull<ibv_cq_ex>,
_phantom: PhantomData<&'cq ()>,
}

impl ExtendedPoller<'_> {
pub fn iter_mut(&mut self) -> ExtendedWorkCompletion {
ExtendedWorkCompletion {
cq: self.cq,
_phantom: PhantomData,
}
}
}

impl<'a> Iterator for ExtendedWorkCompletion<'a> {
type Item = ExtendedWorkCompletion<'a>;

fn next(&mut self) -> Option<Self::Item> {
let ret = unsafe { ibv_next_poll(self.cq.as_ptr()) };

if ret != 0 {
None
} else {
Some(ExtendedWorkCompletion {
cq: self.cq,
_phantom: PhantomData,
})
}
}
}

impl Drop for ExtendedPoller<'_> {
fn drop(&mut self) {
unsafe { ibv_end_poll(self.cq.as_ptr()) }
}
}
Loading