Skip to content

Commit bb32ff8

Browse files
committedOct 2, 2024
fix(rdma): send periodic control messages to sync sender/receiver
We hit a bug where the sender sends a long run of eager messages and ends up outpacing the receiver by more than the width of the message buffer, causing an error. In this patch, receiver will send a control message to sender at least every `msgbuff_size - max_requests` messages, and sender will pause if it hasn't received a control message within this duration. Signed-off-by: Eric Raut <eraut@amazon.com>
1 parent 40cb9d0 commit bb32ff8

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed
 

‎include/nccl_ofi_rdma.h

+3
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ typedef struct nccl_net_ofi_rdma_send_comm {
535535
/* Counters for total sent and received control messages */
536536
uint64_t n_ctrl_received;
537537
uint64_t n_ctrl_expected;
538+
uint16_t last_ctrl_received;
538539

539540
bool comm_active;
540541

@@ -615,6 +616,8 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
615616
uint64_t n_ctrl_sent;
616617
uint64_t n_ctrl_delivered;
617618

619+
uint16_t last_ctrl_sent;
620+
618621
/* Number of rails */
619622
int num_rails;
620623

‎src/nccl_ofi_rdma.c

+59-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
/* Message buffer size -- maximum span of simultaneous inflight messages */
3333
#define NCCL_OFI_RDMA_MSGBUFF_SIZE 256
3434

35+
static_assert(NCCL_OFI_RDMA_MSGBUFF_SIZE > NCCL_OFI_MAX_REQUESTS,
36+
"Message buffer size must be larger than max_requests");
37+
3538
/* Maximum number of comms open simultaneously. Eventually this will be
3639
runtime-expandable */
3740
#define NCCL_OFI_RDMA_MAX_COMMS (1 << NCCL_OFI_RDMA_COMM_ID_BITS)
@@ -91,6 +94,27 @@
9194
#define GET_RDMA_WRITE_IMM_DATA(comm_id, seq, nseg) \
9295
((seq) | ((comm_id) << NCCL_OFI_RDMA_SEQ_BITS) | ((nseg) << (NCCL_OFI_RDMA_SEQ_BITS + NCCL_OFI_RDMA_COMM_ID_BITS)))
9396

97+
static inline uint16_t msg_seq_num_distance(uint16_t right, uint16_t left)
98+
{
99+
assert((right <= MSG_SEQ_NUM_MASK) && (left <= MSG_SEQ_NUM_MASK));
100+
return (right - left) & MSG_SEQ_NUM_MASK;
101+
}
102+
103+
static inline bool receiver_out_of_sync(uint16_t msg_seq_num, uint16_t recv_last_seq_num)
104+
{
105+
/**
106+
* If receiver has started message n, then the last incomplete is no
107+
* less than n - max_requests + 1, and so the receiver's msgbuff has
108+
* space for at least message (last_incomplete + msgbuff_size - 1)
109+
*
110+
* Therefore, the condition for sender to wait for receiver is:
111+
*
112+
* distance(msg_seq_num, last_ctrl_received) == msgbuff_size - max_requests
113+
*/
114+
return msg_seq_num_distance(msg_seq_num, recv_last_seq_num) ==
115+
(NCCL_OFI_RDMA_MSGBUFF_SIZE - NCCL_OFI_MAX_REQUESTS);
116+
}
117+
94118
/** Global variables **/
95119

96120
/* Maximum size of an eager message (see OFI_NCCL_EAGER_MAX_SIZE) */
@@ -1035,7 +1059,7 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm,
10351059
}
10361060

10371061
if (OFI_UNLIKELY(stat != NCCL_OFI_MSGBUFF_INPROGRESS)) {
1038-
NCCL_OFI_WARN("Unexpected message status (%d) (ctrl recv)", (int)stat);
1062+
NCCL_OFI_WARN("Unexpected message status (%d) for msg %hu", (int)stat, msg_seq_num);
10391063
return -EINVAL;
10401064
}
10411065

@@ -1213,6 +1237,11 @@ static inline int handle_bounce_recv(nccl_net_ofi_rdma_device_t *device, int rai
12131237

12141238
nccl_net_ofi_mutex_lock(&s_comm->ctrl_recv_lock);
12151239
s_comm->n_ctrl_received += 1;
1240+
if (msg_seq_num_distance((uint16_t)ctrl_msg->msg_seq_num,
1241+
s_comm->last_ctrl_received)
1242+
<= NCCL_OFI_RDMA_MSGBUFF_SIZE) {
1243+
s_comm->last_ctrl_received = ctrl_msg->msg_seq_num;
1244+
}
12161245
nccl_net_ofi_mutex_unlock(&s_comm->ctrl_recv_lock);
12171246

12181247
break;
@@ -3286,6 +3315,8 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
32863315

32873316
NCCL_OFI_TRACE_RECV(dev_id, r_comm->local_comm_id, sizes[0], req, base_req);
32883317

3318+
bool send_ctrl = false;
3319+
32893320
if (eager) {
32903321
if (recv_data->eager_copy_req == NULL) {
32913322
/* If we don't need to do eager copy, this recv is already complete */
@@ -3302,7 +3333,20 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
33023333
goto error;
33033334
}
33043335
}
3336+
3337+
/* Send a ctrl message if we haven't sent one recently enough
3338+
(see corresponding check in send) */
3339+
if (receiver_out_of_sync(msg_seq_num, r_comm->last_ctrl_sent)) {
3340+
send_ctrl = true;
3341+
} else {
3342+
send_ctrl = false;
3343+
}
33053344
} else {
3345+
send_ctrl = true;
3346+
}
3347+
3348+
if (send_ctrl)
3349+
{
33063350
nccl_net_ofi_rdma_req_t *send_ctrl_req =
33073351
allocate_send_ctrl_req(r_comm, device, dev_id, msg_seq_num,
33083352
buffers[0], sizes[0], mr_handles[0]);
@@ -3322,6 +3366,8 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
33223366
/* TODO: Remove req from message buffer */
33233367
goto error;
33243368
}
3369+
3370+
r_comm->last_ctrl_sent = msg_seq_num;
33253371
}
33263372

33273373
/* Return request to NCCL */
@@ -4173,6 +4219,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
41734219
memset(&r_comm->cleanup_list_elem, 0, sizeof(r_comm->cleanup_list_elem));
41744220
r_comm->n_ctrl_sent = 0;
41754221
r_comm->n_ctrl_delivered = 0;
4222+
r_comm->last_ctrl_sent = (uint16_t)((-1) & MSG_SEQ_NUM_MASK);
41764223

41774224
/* Allocate recv communicator ID */
41784225
comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
@@ -5389,6 +5436,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
53895436
goto error;
53905437
}
53915438

5439+
nccl_net_ofi_mutex_lock(&s_comm->ctrl_recv_lock);
5440+
uint16_t last_ctrl_received = s_comm->last_ctrl_received;
5441+
nccl_net_ofi_mutex_unlock(&s_comm->ctrl_recv_lock);
5442+
53925443
dev_id = s_comm->base.base.dev_id;
53935444

53945445
ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
@@ -5405,6 +5456,12 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
54055456
goto error;
54065457
}
54075458

5459+
if (receiver_out_of_sync(msg_seq_num, last_ctrl_received)) {
5460+
/* Wait for receiver to catch up */
5461+
ret = ofi_process_cq(ep);
5462+
goto free_req;
5463+
}
5464+
54085465
/*
54095466
* TODO: Use NCCL provided tags when using grouped receives aka
54105467
* props->maxRecvs > 1.
@@ -5940,6 +5997,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
59405997
ret_s_comm->received_close_message = false;
59415998
ret_s_comm->n_ctrl_received = 0;
59425999
ret_s_comm->n_ctrl_expected = 0;
6000+
ret_s_comm->last_ctrl_received = (uint16_t)((-1) & MSG_SEQ_NUM_MASK);
59436001

59446002
/* Store communicator ID from handle in communicator */
59456003
if (OFI_UNLIKELY(handle->comm_id >= device->num_comm_ids)) {

0 commit comments

Comments
 (0)