@@ -32,6 +32,9 @@
/* Message buffer size -- maximum span of simultaneous inflight messages */
#define NCCL_OFI_RDMA_MSGBUFF_SIZE 256

+static_assert(NCCL_OFI_RDMA_MSGBUFF_SIZE > NCCL_OFI_MAX_REQUESTS,
+              "Message buffer size must be larger than max_requests");
+
/* Maximum number of comms open simultaneously. Eventually this will be
   runtime-expandable */
#define NCCL_OFI_RDMA_MAX_COMMS (1 << NCCL_OFI_RDMA_COMM_ID_BITS)
@@ -91,6 +94,27 @@
#define GET_RDMA_WRITE_IMM_DATA(comm_id, seq, nseg) \
        ((seq) | ((comm_id) << NCCL_OFI_RDMA_SEQ_BITS) | ((nseg) << (NCCL_OFI_RDMA_SEQ_BITS + NCCL_OFI_RDMA_COMM_ID_BITS)))

+static inline uint16_t msg_seq_num_distance(uint16_t right, uint16_t left)
+{
+        assert((right <= MSG_SEQ_NUM_MASK) && (left <= MSG_SEQ_NUM_MASK));
+        return (right - left) & MSG_SEQ_NUM_MASK;
+}
+
+static inline bool receiver_out_of_sync(uint16_t msg_seq_num, uint16_t recv_last_seq_num)
+{
+        /**
+         * If the receiver has started message n, then the last incomplete
+         * message is no less than n - max_requests + 1, and so the receiver's
+         * msgbuff has space for at least message (last_incomplete + msgbuff_size - 1).
+         *
+         * Therefore, the condition for the sender to wait for the receiver is:
+         *
+         * distance(msg_seq_num, last_ctrl_received) == msgbuff_size - max_requests
+         */
+        return msg_seq_num_distance(msg_seq_num, recv_last_seq_num) ==
+               (NCCL_OFI_RDMA_MSGBUFF_SIZE - NCCL_OFI_MAX_REQUESTS);
+}
+
/** Global variables **/

/* Maximum size of an eager message (see OFI_NCCL_EAGER_MAX_SIZE) */
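Note: to make the wraparound arithmetic in the two helpers above concrete, here is a small standalone sketch of the same distance and threshold logic. The sequence-bit width (10) and the max_requests value (128) are assumptions chosen only for illustration; the plugin's real constants are MSG_SEQ_NUM_MASK and NCCL_OFI_MAX_REQUESTS, while 256 matches the NCCL_OFI_RDMA_MSGBUFF_SIZE defined earlier in this diff.

/* Illustrative sketch only: SEQ_BITS and MAX_REQUESTS are assumed values,
 * not the plugin's; MSGBUFF_SIZE matches the 256 defined above. */
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define SEQ_BITS     10                      /* assumed sequence-number width */
#define SEQ_MASK     ((1u << SEQ_BITS) - 1)  /* plays the role of MSG_SEQ_NUM_MASK */
#define MSGBUFF_SIZE 256u                    /* NCCL_OFI_RDMA_MSGBUFF_SIZE */
#define MAX_REQUESTS 128u                    /* assumed NCCL_OFI_MAX_REQUESTS */

/* Same shape as msg_seq_num_distance(): how far "right" is ahead of "left",
 * modulo the sequence-number space. */
static uint16_t seq_distance(uint16_t right, uint16_t left)
{
        assert((right <= SEQ_MASK) && (left <= SEQ_MASK));
        return (uint16_t)((right - left) & SEQ_MASK);
}

/* Same shape as receiver_out_of_sync(): the sender pauses when it is exactly
 * msgbuff_size - max_requests messages ahead of the last ctrl it received. */
static bool out_of_sync(uint16_t msg_seq_num, uint16_t last_ctrl_received)
{
        return seq_distance(msg_seq_num, last_ctrl_received) ==
               (MSGBUFF_SIZE - MAX_REQUESTS);
}

int main(void)
{
        /* Wraparound: sequence 5 is 15 steps ahead of sequence SEQ_MASK - 9. */
        printf("%u\n", (unsigned)seq_distance(5, SEQ_MASK - 9));  /* 15 */
        /* Sender exactly 128 messages ahead of the last ctrl it saw: wait. */
        printf("%d\n", out_of_sync(200, 72));                     /* 1 */
        /* One ctrl newer and the sender may proceed again. */
        printf("%d\n", out_of_sync(200, 73));                     /* 0 */
        return 0;
}

Compiled with any C compiler, this prints 15, 1, 0: the distance wraps correctly across the sequence-number rollover, and the sender is gated exactly when it is msgbuff_size - max_requests messages ahead of the last ctrl it has seen.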
@@ -1035,7 +1059,7 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm,
        }

        if (OFI_UNLIKELY(stat != NCCL_OFI_MSGBUFF_INPROGRESS)) {
-               NCCL_OFI_WARN("Unexpected message status (%d) (ctrl recv)", (int)stat);
+               NCCL_OFI_WARN("Unexpected message status (%d) for msg %hu", (int)stat, msg_seq_num);
                return -EINVAL;
        }

@@ -1213,6 +1237,11 @@ static inline int handle_bounce_recv(nccl_net_ofi_rdma_device_t *device, int rai

                nccl_net_ofi_mutex_lock(&s_comm->ctrl_recv_lock);
                s_comm->n_ctrl_received += 1;
+               if (msg_seq_num_distance((uint16_t)ctrl_msg->msg_seq_num,
+                                        s_comm->last_ctrl_received)
+                   <= NCCL_OFI_RDMA_MSGBUFF_SIZE) {
+                       s_comm->last_ctrl_received = ctrl_msg->msg_seq_num;
+               }
                nccl_net_ofi_mutex_unlock(&s_comm->ctrl_recv_lock);

                break;
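Note: ctrl messages may complete out of order on the sender side, which is presumably why the update above is guarded rather than being a plain assignment. The sketch below (10-bit mask assumed for illustration) shows the guard letting in-window ctrls advance last_ctrl_received while a stale ctrl, whose wrapped distance exceeds the buffer size, leaves the watermark untouched.

/* Illustrative sketch only: a 10-bit mask is assumed here in place of the
 * plugin's MSG_SEQ_NUM_MASK. */
#include <stdint.h>
#include <stdio.h>

#define SEQ_MASK     ((1u << 10) - 1)  /* assumed stand-in for MSG_SEQ_NUM_MASK */
#define MSGBUFF_SIZE 256u              /* NCCL_OFI_RDMA_MSGBUFF_SIZE */

static uint16_t seq_distance(uint16_t right, uint16_t left)
{
        return (uint16_t)((right - left) & SEQ_MASK);
}

int main(void)
{
        uint16_t last_ctrl_received = 100;
        uint16_t incoming[] = { 101, 99, 103 };  /* 99 arrives late, out of order */

        for (unsigned i = 0; i < 3; i++) {
                /* Same guard as in handle_bounce_recv(): only a ctrl at most
                 * MSGBUFF_SIZE ahead of the watermark may advance it; a stale
                 * ctrl wraps to a huge distance and is ignored. */
                if (seq_distance(incoming[i], last_ctrl_received) <= MSGBUFF_SIZE) {
                        last_ctrl_received = incoming[i];
                }
                printf("after ctrl %u: last_ctrl_received = %u\n",
                       (unsigned)incoming[i], (unsigned)last_ctrl_received);
        }
        /* Prints 101, then still 101 when the stale 99 arrives, then 103. */
        return 0;
}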
@@ -3286,6 +3315,8 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,

        NCCL_OFI_TRACE_RECV(dev_id, r_comm->local_comm_id, sizes[0], req, base_req);

+       bool send_ctrl = false;
+
        if (eager) {
                if (recv_data->eager_copy_req == NULL) {
                        /* If we don't need to do eager copy, this recv is already complete */
@@ -3302,7 +3333,20 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
                                goto error;
                        }
                }
+
+               /* Send a ctrl message if we haven't sent one recently enough
+                  (see corresponding check in send) */
+               if (receiver_out_of_sync(msg_seq_num, r_comm->last_ctrl_sent)) {
+                       send_ctrl = true;
+               } else {
+                       send_ctrl = false;
+               }
        } else {
+               send_ctrl = true;
+       }
+
+       if (send_ctrl)
+       {
                nccl_net_ofi_rdma_req_t *send_ctrl_req =
                        allocate_send_ctrl_req(r_comm, device, dev_id, msg_seq_num,
                                               buffers[0], sizes[0], mr_handles[0]);
@@ -3322,6 +3366,8 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
                        /* TODO: Remove req from message buffer */
                        goto error;
                }
+
+               r_comm->last_ctrl_sent = msg_seq_num;
        }

        /* Return request to NCCL */
@@ -4173,6 +4219,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
        memset(&r_comm->cleanup_list_elem, 0, sizeof(r_comm->cleanup_list_elem));
        r_comm->n_ctrl_sent = 0;
        r_comm->n_ctrl_delivered = 0;
+       r_comm->last_ctrl_sent = (uint16_t)((-1) & MSG_SEQ_NUM_MASK);

        /* Allocate recv communicator ID */
        comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
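Note: initializing the watermark to (uint16_t)((-1) & MSG_SEQ_NUM_MASK) (here for last_ctrl_sent, and likewise for last_ctrl_received in create_send_comm below) parks it on the largest sequence value, i.e. one step behind the first sequence number 0, so the distance math needs no special case for a brand-new communicator. A minimal check of that arithmetic, again under an assumed 10-bit mask:

/* Quick check (10-bit mask assumed) that starting the watermark at
 * (-1) & MSG_SEQ_NUM_MASK puts it exactly one step behind sequence 0. */
#include <stdint.h>
#include <stdio.h>

#define SEQ_MASK ((1u << 10) - 1)  /* assumed stand-in for MSG_SEQ_NUM_MASK */

static uint16_t seq_distance(uint16_t right, uint16_t left)
{
        return (uint16_t)((right - left) & SEQ_MASK);
}

int main(void)
{
        uint16_t initial = (uint16_t)((-1) & SEQ_MASK);  /* == SEQ_MASK, the max value */

        printf("initial watermark      = %u\n", (unsigned)initial);                    /* 1023 */
        printf("distance(0, initial)   = %u\n", (unsigned)seq_distance(0, initial));   /* 1 */
        printf("distance(127, initial) = %u\n", (unsigned)seq_distance(127, initial)); /* 128 */
        return 0;
}

Under these illustrative constants (msgbuff_size 256, max_requests 128) the gate in send() would first engage at sequence number 127.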
@@ -5389,6 +5436,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
                goto error;
        }

+       nccl_net_ofi_mutex_lock(&s_comm->ctrl_recv_lock);
+       uint16_t last_ctrl_received = s_comm->last_ctrl_received;
+       nccl_net_ofi_mutex_unlock(&s_comm->ctrl_recv_lock);
+
        dev_id = s_comm->base.base.dev_id;

        ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
@@ -5405,6 +5456,12 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
                goto error;
        }

+       if (receiver_out_of_sync(msg_seq_num, last_ctrl_received)) {
+               /* Wait for receiver to catch up */
+               ret = ofi_process_cq(ep);
+               goto free_req;
+       }
+
        /*
         * TODO: Use NCCL provided tags when using grouped receives aka
         * props->maxRecvs > 1.
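Note: when the gate above fires, send() posts nothing; it progresses the completion queue (which is what eventually delivers a newer ctrl and bumps last_ctrl_received) and takes the free_req path, so the request is presumably released and the caller simply retries the isend later. The toy model below, using the same illustrative constants as the earlier sketches, shows the sender stalling only at the msgbuff_size - max_requests boundary and resuming as soon as a newer ctrl arrives.

/* Toy model only, not plugin code; same assumed constants as above.
 * The sender stops issuing new sequence numbers once it is exactly
 * msgbuff_size - max_requests ahead of the last ctrl it has seen, and
 * resumes as soon as a newer ctrl "arrives". */
#include <stdint.h>
#include <stdio.h>

#define SEQ_MASK     ((1u << 10) - 1)  /* assumed MSG_SEQ_NUM_MASK stand-in */
#define MSGBUFF_SIZE 256u              /* NCCL_OFI_RDMA_MSGBUFF_SIZE */
#define MAX_REQUESTS 128u              /* assumed NCCL_OFI_MAX_REQUESTS */

static uint16_t seq_distance(uint16_t right, uint16_t left)
{
        return (uint16_t)((right - left) & SEQ_MASK);
}

int main(void)
{
        uint16_t next_seq = 0;           /* next message the sender wants to post */
        uint16_t last_ctrl = SEQ_MASK;   /* "one before 0", matching the
                                          * last_ctrl_received init in
                                          * create_send_comm below */
        unsigned posted = 0, deferred = 0;

        for (unsigned attempt = 0; attempt < 400; attempt++) {
                if (seq_distance(next_seq, last_ctrl) == (MSGBUFF_SIZE - MAX_REQUESTS)) {
                        /* Out of sync: defer, as send() does via goto free_req. */
                        deferred++;
                        /* Pretend that progressing the CQ delivered one more ctrl. */
                        last_ctrl = (uint16_t)((last_ctrl + 1) & SEQ_MASK);
                        continue;
                }
                next_seq = (uint16_t)((next_seq + 1) & SEQ_MASK);
                posted++;
        }
        printf("posted %u sends, deferred %u attempts\n", posted, deferred);
        return 0;
}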
@@ -5940,6 +5997,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
        ret_s_comm->received_close_message = false;
        ret_s_comm->n_ctrl_received = 0;
        ret_s_comm->n_ctrl_expected = 0;
+       ret_s_comm->last_ctrl_received = (uint16_t)((-1) & MSG_SEQ_NUM_MASK);

        /* Store communicator ID from handle in communicator */
        if (OFI_UNLIKELY(handle->comm_id >= device->num_comm_ids)) {