diff --git a/src/team_lib/mhba/xccl_mhba_collective.c b/src/team_lib/mhba/xccl_mhba_collective.c
index d8a4f5f..187a179 100644
--- a/src/team_lib/mhba/xccl_mhba_collective.c
+++ b/src/team_lib/mhba/xccl_mhba_collective.c
@@ -6,7 +6,7 @@
 #include
 #include "utils/utils.h"
 
-#define TMP_TRANSPOSE_PREALLOC 256
+#define TMP_TRANSPOSE_PREALLOC 256 //todo check size
 
 xccl_status_t xccl_mhba_collective_init_base(xccl_coll_op_args_t *coll_args,
                                              xccl_mhba_coll_req_t **request,
@@ -68,22 +68,24 @@ static xccl_status_t xccl_mhba_reg_fanin_start(xccl_coll_task_t *task)
 
     xccl_mhba_debug("register memory buffers");
 
-    request->send_bf_mr = ibv_reg_mr(
-        team->node.shared_pd, (void *)request->args.buffer_info.src_buffer,
-        request->args.buffer_info.len * team->size, sr_mem_access_flags);
-    if (!request->send_bf_mr) {
+    ucs_rcache_region_t* send_ptr;
+    ucs_rcache_region_t* recv_ptr;
+    if(UCS_OK != ucs_rcache_get(team->context->rcache, (void *)request->args.buffer_info.src_buffer,
+                                request->args.buffer_info.len * team->size,
+                                PROT_READ,&sr_mem_access_flags, &send_ptr)) {
         xccl_mhba_error("Failed to register send_bf memory (errno=%d)", errno);
         return XCCL_ERR_NO_RESOURCE;
     }
-    request->receive_bf_mr = ibv_reg_mr(
-        team->node.shared_pd, (void *)request->args.buffer_info.dst_buffer,
-        request->args.buffer_info.len * team->size, dr_mem_access_flags);
-    if (!request->receive_bf_mr) {
-        xccl_mhba_error("Failed to register receive_bf memory (errno=%d)",
-                        errno);
-        ibv_dereg_mr(request->send_bf_mr);
+    request->send_rcache_region_p = xccl_rcache_ucs_get_reg_data(send_ptr);
+
+    if(UCS_OK != ucs_rcache_get(team->context->rcache, (void *)request->args.buffer_info.dst_buffer,
+                                request->args.buffer_info.len * team->size,
+                                PROT_WRITE,&dr_mem_access_flags,&recv_ptr)) {
+        xccl_mhba_error("Failed to register receive_bf memory");
+        ucs_rcache_region_put(team->context->rcache,request->send_rcache_region_p->region);
         return XCCL_ERR_NO_RESOURCE;
     }
+    request->recv_rcache_region_p = xccl_rcache_ucs_get_reg_data(recv_ptr);
 
     xccl_mhba_debug("fanin start");
     /* start task if completion event received */
@@ -247,7 +249,7 @@ static inline xccl_status_t send_block_data(struct ibv_qp *qp,
 
 static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr,
                                         uint32_t rkey, xccl_mhba_team_t *team,
-                                        xccl_mhba_coll_req_t *request)
+                                        xccl_mhba_coll_req_t *request, int flags)
 {
     struct ibv_send_wr *bad_wr;
     struct ibv_sge list = {
@@ -261,7 +263,7 @@ static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr,
         .sg_list = &list,
         .num_sge = 1,
         .opcode = IBV_WR_ATOMIC_FETCH_AND_ADD,
-        .send_flags = IBV_SEND_SIGNALED,
+        .send_flags = flags,
         .wr.atomic.remote_addr = remote_addr,
         .wr.atomic.rkey = rkey,
         .wr.atomic.compare_add = 1ULL,
@@ -311,10 +313,11 @@ static inline xccl_status_t prepost_dummy_recv(struct ibv_qp *qp, int num)
     return XCCL_OK;
 }
 
+//todo check cq's in case of parallel operations
 // add polling mechanism for blocks in order to maintain const qp tx rx
 static xccl_status_t
 xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
-{
+{ //todo make non-blocking
     xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t);
     xccl_mhba_coll_req_t *request = self->req;
     xccl_mhba_team_t *team = request->team;
@@ -384,16 +387,13 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
                 xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
                 return status;
             }
-            while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {}
+            while (!ibv_poll_cq(team->net.cqs[i], 1, transpose_completion)) {}
             }
         }
-    }
-
-    for (i = 0; i < net_size; i++) {
         status = send_atomic(team->net.qps[i],
                              (uintptr_t)team->net.remote_ctrl[i].addr +
-                                 (index * MHBA_CTRL_SIZE),
-                             team->net.remote_ctrl[i].rkey, team, request);
+                                 (index * MHBA_CTRL_SIZE),
+                             team->net.remote_ctrl[i].rkey, team, request,0);
         if (status != XCCL_OK) {
             xccl_mhba_error("Failed sending atomic to node [%d]", i);
             return status;
@@ -448,7 +448,7 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
         status = send_atomic(team->net.qps[i],
                              (uintptr_t)team->net.remote_ctrl[i].addr +
                                  (index * MHBA_CTRL_SIZE),
-                             team->net.remote_ctrl[i].rkey, team, request);
+                             team->net.remote_ctrl[i].rkey, team, request,IBV_SEND_SIGNALED);
         if (status != XCCL_OK) {
             xccl_mhba_error("Failed sending atomic to node [%d]", i);
             return status;
@@ -458,13 +458,18 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
     return XCCL_OK;
 }
 
+xccl_status_t xccl_mhba_send_blocks_progress_transpose(xccl_coll_task_t *task){
+    task->state = XCCL_TASK_STATE_COMPLETED;
+    return XCCL_OK;
+}
+
 xccl_status_t xccl_mhba_send_blocks_progress(xccl_coll_task_t *task)
 {
     xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t);
     xccl_mhba_coll_req_t *request = self->req;
     xccl_mhba_team_t *team = request->team;
     int i, completions_num;
-    completions_num = ibv_poll_cq(team->net.cq, team->net.sbgp->group_size,
+    completions_num = ibv_poll_cq(team->net.cqs[0], team->net.sbgp->group_size,
                                   team->work_completion);
     if (completions_num < 0) {
         xccl_mhba_error("ibv_poll_cq() failed for RDMA_ATOMIC execution");
@@ -506,11 +511,13 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
     }
     xccl_schedule_init(&request->schedule, team->super.ctx);
     if (team->transpose_hw_limitations) {
-        block_size = team->blocks_sizes[__ucs_ilog2_u32(len - 1)];
+        block_size = (len == 1) ? team->blocks_sizes[0] : team->blocks_sizes[__ucs_ilog2_u32(len - 1) + 1];
     } else {
         block_size = team->node.sbgp->group_size;
     }
+    block_size = team->requested_block_size ? team->requested_block_size : block_size;
+    //todo following section correct assuming homogenous PPN across all nodes
     if (team->node.sbgp->group_size % block_size != 0) {
         if (team->node.sbgp->group_rank == team->node.asr_rank) {
@@ -536,6 +543,10 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
     request->tmp_transpose_buf = NULL;
     request->tasks = (xccl_mhba_task_t *)malloc(sizeof(xccl_mhba_task_t) *
                                                 n_tasks);
+    if (!request->tasks){
+        xccl_mhba_error("malloc tasks failed");
+        return XCCL_ERR_NO_MEMORY;
+    }
     request->seq_num = team->sequence_number;
     xccl_mhba_debug("Seq num is %d", request->seq_num);
     team->sequence_number++;
@@ -571,11 +582,12 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
     if (team->transpose) {
         request->tasks[2].super.handlers[XCCL_EVENT_COMPLETED] =
             xccl_mhba_send_blocks_start_with_transpose;
+        request->tasks[2].super.progress = xccl_mhba_send_blocks_progress_transpose;
     } else {
         request->tasks[2].super.handlers[XCCL_EVENT_COMPLETED] =
             xccl_mhba_send_blocks_start;
+        request->tasks[2].super.progress = xccl_mhba_send_blocks_progress;
     }
-    request->tasks[2].super.progress = xccl_mhba_send_blocks_progress;
 
     request->tasks[3].super.handlers[XCCL_EVENT_COMPLETED] =
         xccl_mhba_fanout_start;
diff --git a/src/team_lib/mhba/xccl_mhba_collective.h b/src/team_lib/mhba/xccl_mhba_collective.h
index 38f6ec8..6ba4ece 100644
--- a/src/team_lib/mhba/xccl_mhba_collective.h
+++ b/src/team_lib/mhba/xccl_mhba_collective.h
@@ -21,12 +21,12 @@ typedef struct xccl_mhba_coll_req {
     xccl_mhba_task_t *tasks;
     xccl_coll_op_args_t args;
     xccl_mhba_team_t *team;
-    int seq_num;
-    struct ibv_mr *send_bf_mr;
-    struct ibv_mr *receive_bf_mr;
+    uint64_t seq_num;
     xccl_tl_coll_req_t *barrier_req;
     int block_size;
     int started;
+    xccl_mhba_reg_t *send_rcache_region_p;
+    xccl_mhba_reg_t *recv_rcache_region_p;
     struct ibv_mr *transpose_buf_mr;
     void *tmp_transpose_buf;
 } xccl_mhba_coll_req_t;
diff --git a/src/team_lib/mhba/xccl_mhba_lib.c b/src/team_lib/mhba/xccl_mhba_lib.c
index fb005be..b42d8fd 100644
--- a/src/team_lib/mhba/xccl_mhba_lib.c
+++ b/src/team_lib/mhba/xccl_mhba_lib.c
@@ -28,10 +28,10 @@ static ucs_config_field_t xccl_tl_mhba_context_config_table[] = {
      ucs_offsetof(xccl_tl_mhba_context_config_t, transpose),
      UCS_CONFIG_TYPE_UINT},
 
-    {"TRANSPOSE_HW_LIMITATIONS", "1",
+    {"TRANSPOSE_HW_LIMITATIONS", "0",
      "Boolean - with transpose hw limitations or not",
      ucs_offsetof(xccl_tl_mhba_context_config_t, transpose_hw_limitations),
-     UCS_CONFIG_TYPE_UINT},
+     UCS_CONFIG_TYPE_UINT}, //todo change to 1 in production
 
     {"IB_GLOBAL", "0", "Use global ib routing",
      ucs_offsetof(xccl_tl_mhba_context_config_t, ib_global),
@@ -41,6 +41,10 @@ static ucs_config_field_t xccl_tl_mhba_context_config_table[] = {
      ucs_offsetof(xccl_tl_mhba_context_config_t, transpose_buf_size),
      UCS_CONFIG_TYPE_MEMUNITS},
 
+    {"BLOCK_SIZE", "0", "Size of the blocks that are sent using blocked AlltoAll Algorithm",
+     ucs_offsetof(xccl_tl_mhba_context_config_t, block_size),
+     UCS_CONFIG_TYPE_UINT},
+
     {NULL}
 };
@@ -187,17 +191,10 @@ static xccl_status_t xccl_mhba_collective_finalize(xccl_tl_coll_req_t *request)
     xccl_status_t status = XCCL_OK;
     xccl_mhba_coll_req_t *req = ucs_derived_of(request, xccl_mhba_coll_req_t);
     xccl_mhba_team_t * team = req->team;
-    if (ibv_dereg_mr(req->send_bf_mr)) {
-        xccl_mhba_error("Failed to dereg_mr send buffer (errno=%d)", errno);
-        status = XCCL_ERR_NO_MESSAGE;
-    }
-    if (ibv_dereg_mr(req->receive_bf_mr)) {
-        xccl_mhba_error("Failed to dereg_mr send buffer (errno=%d)", errno);
-        status = XCCL_ERR_NO_MESSAGE;
-    }
+    ucs_rcache_region_put(team->context->rcache,req->send_rcache_region_p->region);
+    ucs_rcache_region_put(team->context->rcache,req->recv_rcache_region_p->region);
     if (team->transpose) {
-        if (req->tmp_transpose_buf)
-            free(req->tmp_transpose_buf);
+        free(req->tmp_transpose_buf);
         if (req->transpose_buf_mr != team->transpose_buf_mr) {
             ibv_dereg_mr(req->transpose_buf_mr);
             free(req->transpose_buf_mr->addr);
diff --git a/src/team_lib/mhba/xccl_mhba_lib.h b/src/team_lib/mhba/xccl_mhba_lib.h
index e327ecd..6b442fd 100644
--- a/src/team_lib/mhba/xccl_mhba_lib.h
+++ b/src/team_lib/mhba/xccl_mhba_lib.h
@@ -11,6 +11,7 @@
 #include "topo/xccl_topo.h"
 #include
 #include
+#include
 
 #define MAX_OUTSTANDING_OPS 1 //todo change - according to limitations (52 top)
 #define SEQ_INDEX(_seq_num) ((_seq_num) % MAX_OUTSTANDING_OPS)
@@ -27,10 +28,9 @@ typedef struct xccl_tl_mhba_context_config {
     int transpose;
     int transpose_hw_limitations;
     size_t transpose_buf_size;
+    int block_size;
 } xccl_tl_mhba_context_config_t;
 
-//todo add block_size config
-
 typedef struct xccl_team_lib_mhba {
     xccl_team_lib_t super;
     xccl_team_lib_mhba_config_t config;
@@ -73,6 +73,7 @@ typedef struct xccl_mhba_context {
     struct ibv_context *ib_ctx;
     struct ibv_pd *ib_pd;
     int ib_port;
+    ucs_rcache_t *rcache;
 } xccl_mhba_context_t;
 
 typedef struct xccl_mhba_op {
@@ -103,19 +104,28 @@ typedef struct xccl_mhba_node {
     struct mlx5dv_qp_ex *umr_mlx5dv_qp_ex;
 } xccl_mhba_node_t;
 
-#define MHBA_CTRL_SIZE 128 //todo change according to arch
+#define MHBA_CTRL_SIZE 128 //todo change to UCS_ARCH_CACHE_LINE_SIZE
 #define MHBA_DATA_SIZE sizeof(struct mlx5dv_mr_interleaved)
-#define MHBA_NUM_OF_BLOCKS_SIZE_BINS 7
+#define MHBA_NUM_OF_BLOCKS_SIZE_BINS 8
 #define MAX_TRANSPOSE_SIZE 8000 // HW transpose unit is limited to matrix size
 #define MAX_MSG_SIZE 128 // HW transpose unit is limited to element size
 #define MAX_STRIDED_ENTRIES 55 // from limit of NIC memory - Sergey Gorenko's email
 
+typedef struct xccl_mhba_reg {
+    struct ibv_mr *mr;
+    ucs_rcache_region_t *region;
+} xccl_mhba_reg_t;
+
+static inline xccl_mhba_reg_t* xccl_rcache_ucs_get_reg_data(ucs_rcache_region_t *region) {
+    return (xccl_mhba_reg_t *)((ptrdiff_t)region + sizeof(ucs_rcache_region_t));
+}
+
 typedef struct xccl_mhba_net {
     xccl_sbgp_t *sbgp;
     int net_size;
     int *rank_map;
     struct ibv_qp **qps;
-    struct ibv_cq *cq;
+    struct ibv_cq **cqs;
     struct ibv_mr *ctrl_mr;
     struct {
         void *addr;
@@ -132,13 +142,14 @@ typedef struct xccl_mhba_team {
     uint64_t max_msg_size;
     xccl_mhba_node_t node;
     xccl_mhba_net_t net;
-    int sequence_number;
+    uint64_t sequence_number;
     int op_busy[MAX_OUTSTANDING_OPS];
     int cq_completions[MAX_OUTSTANDING_OPS];
     xccl_mhba_context_t *context;
     int blocks_sizes[MHBA_NUM_OF_BLOCKS_SIZE_BINS];
     int size;
     uint64_t dummy_atomic_buff;
+    int requested_block_size;
     struct ibv_mr *dummy_bf_mr;
     struct ibv_wc *work_completion;
     void *transpose_buf;
diff --git a/src/team_lib/mhba/xccl_mhba_mkeys.c b/src/team_lib/mhba/xccl_mhba_mkeys.c
index 12eb40c..d9f0f44 100644
--- a/src/team_lib/mhba/xccl_mhba_mkeys.c
+++ b/src/team_lib/mhba/xccl_mhba_mkeys.c
@@ -308,7 +308,7 @@ static void update_mkey_entry(xccl_mhba_node_t *node, xccl_mhba_coll_req_t *req,
         (struct mlx5dv_mr_interleaved *)(direction_send ?
                                              node->ops[index].my_send_umr_data :
                                              node->ops[index].my_recv_umr_data);
-    struct ibv_mr *buff = direction_send ? req->send_bf_mr : req->receive_bf_mr;
+    struct ibv_mr *buff = direction_send ? req->send_rcache_region_p->mr : req->recv_rcache_region_p->mr;
     mkey_entry->addr = (uintptr_t)buff->addr;
     mkey_entry->bytes_count = req->block_size * req->args.buffer_info.len;
     mkey_entry->bytes_skip = 0;
diff --git a/src/team_lib/mhba/xccl_mhba_team.c b/src/team_lib/mhba/xccl_mhba_team.c
index 0b6cba7..1910e94 100644
--- a/src/team_lib/mhba/xccl_mhba_team.c
+++ b/src/team_lib/mhba/xccl_mhba_team.c
@@ -6,6 +6,7 @@
 #include "core/xccl_team.h"
 #include "xccl_mhba_ib.h"
 #include
+#include
 
 typedef struct bcast_data {
     int shmid;
@@ -46,13 +47,13 @@ static void calc_block_size(xccl_mhba_team_t *team)
 {
     int i;
     int block_size = team->node.sbgp->group_size;
-    int msg_len = MAX_MSG_SIZE;
-    for (i = MHBA_NUM_OF_BLOCKS_SIZE_BINS - 1; i >= 0; i--) {
+    int msg_len = 1;
+    for (i = 0; i < MHBA_NUM_OF_BLOCKS_SIZE_BINS; i++) {
         while ((block_size * block_size) * msg_len > MAX_TRANSPOSE_SIZE) {
             block_size -= 1;
         }
         team->blocks_sizes[i] = block_size;
-        msg_len >> 1;
+        msg_len <<= 1;
     }
 }
 
@@ -87,6 +88,56 @@ static void build_rank_map(xccl_mhba_team_t *mhba_team)
     free(data);
 }
 
+static ucs_status_t rcache_reg_mr(void *context, ucs_rcache_t *rcache,void *arg, ucs_rcache_region_t *rregion,
+                                  uint16_t flags){
+    xccl_mhba_team_t *team = (xccl_mhba_team_t*)context;
+    void *addr = (void*)rregion->super.start;
+    size_t length = (size_t)(rregion->super.end - rregion->super.start);
+    xccl_mhba_reg_t* mhba_reg = xccl_rcache_ucs_get_reg_data(rregion);
+    mhba_reg->region = rregion;
+    int* mem_flags = (int*) arg;
+    mhba_reg->mr = ibv_reg_mr(team->node.shared_pd, addr, length, *mem_flags);
+    if (!mhba_reg->mr) {
+        xccl_mhba_error("Failed to register memory");
+        return UCS_ERR_NO_MESSAGE;
+    }
+    return UCS_OK;
+}
+
+static void rcache_dereg_mr(void *context, ucs_rcache_t *rcache, ucs_rcache_region_t *rregion) {
+    xccl_mhba_reg_t* mhba_reg = xccl_rcache_ucs_get_reg_data(rregion);
+    assert(mhba_reg->region == rregion);
+    ibv_dereg_mr(mhba_reg->mr);
+    mhba_reg->mr = NULL;
+}
+
+static xccl_status_t create_rcache(xccl_mhba_team_t* mhba_team) {
+    static ucs_rcache_ops_t rcache_ucs_ops = {
+        .mem_reg = rcache_reg_mr,
+        .mem_dereg = rcache_dereg_mr,
+        .dump_region = NULL
+    };
+
+    ucs_rcache_params_t rcache_params;
+    rcache_params.region_struct_size = sizeof(ucs_rcache_region_t)+sizeof(xccl_mhba_reg_t);
+    rcache_params.alignment = UCS_PGT_ADDR_ALIGN;
+    rcache_params.max_alignment = ucs_get_page_size();
+    rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED |
+                               UCM_EVENT_MEM_TYPE_FREE;
+    rcache_params.ucm_event_priority = 1000;
+    rcache_params.context = (void*)mhba_team; //todo maybe change, needed for shared_pd
+    rcache_params.ops = &rcache_ucs_ops;
+
+    ucs_status_t status = ucs_rcache_create(&rcache_params, "reg cache",
+                                            ucs_stats_get_root(), &mhba_team->context->rcache);
+
+    if (status != UCS_OK) {
+        xccl_mhba_error("Failed to create reg cache");
+        return XCCL_ERR_NO_MESSAGE;
+    }
+    return XCCL_OK;
+}
+
 xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
                                          xccl_team_params_t *params,
                                          xccl_team_t *base_team,
@@ -118,6 +169,10 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
 
     memset(mhba_team->op_busy, 0, MAX_OUTSTANDING_OPS * sizeof(int));
 
+    if(XCCL_OK != create_rcache(mhba_team)){
+        goto fail;
+    }
+
     node = xccl_team_topo_get_sbgp(base_team->topo, XCCL_SBGP_NODE);
     if (node->group_size > MAX_STRIDED_ENTRIES) {
         xccl_mhba_error("PPN too large");
@@ -206,9 +261,10 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
     xccl_sbgp_oob_barrier(node, params->oob);
 
     calc_block_size(mhba_team);
+    mhba_team->requested_block_size = ctx->cfg.block_size;
     if (mhba_team->node.asr_rank == node->group_rank) {
         if (mhba_team->transpose) {
-            mhba_team->transpose_buf = malloc(ctx->cfg.transpose_buf_size);
+            mhba_team->transpose_buf = malloc(ctx->cfg.transpose_buf_size); //todo malloc per operation for parallel
             if (!mhba_team->transpose_buf) {
                 goto fail_after_shmat;
             }
@@ -221,20 +277,11 @@
         status = xccl_mhba_init_umr(ctx, &mhba_team->node);
         if (status != XCCL_OK) {
             xccl_mhba_error("Failed to init UMR");
-            goto fail_after_shmat;
-        }
-        asr_cq_size = net_size * MAX_OUTSTANDING_OPS;
-        mhba_team->net.cq = ibv_create_cq(mhba_team->node.shared_ctx,
-                                          asr_cq_size, NULL, NULL, 0);
-        if (!mhba_team->net.cq) {
-            xccl_mhba_error("failed to allocate ASR CQ");
-            goto fail_after_shmat;
+            goto fail_after_transpose_reg;
         }
 
         memset(&qp_init_attr, 0, sizeof(qp_init_attr));
         //todo change in case of non-homogenous ppn
-        qp_init_attr.send_cq = mhba_team->net.cq;
-        qp_init_attr.recv_cq = mhba_team->net.cq;
         qp_init_attr.cap.max_send_wr =
             (SQUARED(node_size / 2) + 1) * MAX_OUTSTANDING_OPS; // TODO switch back to fixed tx/rx
         qp_init_attr.cap.max_recv_wr =
@@ -247,15 +294,21 @@
     mhba_team->net.qps = malloc(sizeof(struct ibv_qp *) * net_size);
     if (!mhba_team->net.qps) {
         xccl_mhba_error("failed to allocate asr qps array");
-        goto fail_after_cq;
+        goto fail_after_transpose_reg;
+    }
+    mhba_team->net.cqs = malloc(sizeof(struct ibv_cq *) * (mhba_team->transpose ? net_size : 1));
+    if (!mhba_team->net.cqs) {
+        xccl_mhba_error("failed to allocate asr cqs array");
+        goto fail_after_qp_alloc;
     }
 
+    // for each ASR - qp num, in addition to port lid, ctrl segment rkey and address, recieve mkey rkey
     local_data_size = (net_size * sizeof(uint32_t)) + sizeof(uint32_t) +
                       2 * sizeof(uint32_t) + sizeof(void *);
     local_data = malloc(local_data_size);
     if (!local_data) {
         xccl_mhba_error("failed to allocate local data");
-        goto local_data_fail;
+        goto fail_after_cq_alloc;
     }
     global_data = malloc(local_data_size * net_size);
     if (!global_data) {
@@ -264,12 +317,24 @@
 
     for (i = 0; i < net_size; i++) {
+        if(i == 0 || mhba_team->transpose){
+            mhba_team->net.cqs[i] = ibv_create_cq(mhba_team->node.shared_ctx, mhba_team->transpose ?
+                                                  MAX_OUTSTANDING_OPS : net_size * MAX_OUTSTANDING_OPS, NULL, NULL, 0);
+            if (!mhba_team->net.cqs[i]) {
+                xccl_mhba_error("failed to create cq for dest %d, errno %d", i,
+                                errno);
+                goto cq_qp_creation_fail;
+            }
+            qp_init_attr.send_cq = mhba_team->net.cqs[i];
+            qp_init_attr.recv_cq = mhba_team->net.cqs[i];
+        }
+
         mhba_team->net.qps[i] =
             ibv_create_qp(mhba_team->node.shared_pd, &qp_init_attr);
         if (!mhba_team->net.qps[i]) {
             xccl_mhba_error("failed to create qp for dest %d, errno %d", i,
                             errno);
-            goto ctrl_fail;
+            goto cq_qp_creation_fail;
         }
         local_data[i] = mhba_team->net.qps[i]->qp_num;
     }
@@ -281,7 +346,7 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
                      IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_LOCAL_WRITE);
     if (!mhba_team->net.ctrl_mr) {
         xccl_mhba_error("failed to register control data, errno %d", errno);
-        goto ctrl_fail;
+        goto cq_qp_creation_fail;
     }
     ibv_query_port(ctx->ib_ctx, ctx->ib_port, &port_attr);
     local_data[net_size] = port_attr.lid;
@@ -380,16 +445,21 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
     ibv_dereg_mr(mhba_team->dummy_bf_mr);
 remote_ctrl_fail:
     ibv_dereg_mr(mhba_team->net.ctrl_mr);
-ctrl_fail:
+cq_qp_creation_fail:
     free(global_data);
+    for (i = 0; i < net_size; i++){
+        ibv_destroy_cq(mhba_team->net.cqs[i]);
+        ibv_destroy_qp(mhba_team->net.qps[i]);
+    }
 global_data_fail:
     free(local_data);
-local_data_fail:
+fail_after_cq_alloc:
+    free(mhba_team->net.cqs);
+fail_after_qp_alloc:
     free(mhba_team->net.qps);
-fail_after_cq:
-    if (ibv_destroy_cq(mhba_team->net.cq)) {
-        xccl_mhba_error("net cq destroy failed (errno=%d)", errno);
-    }
+fail_after_transpose_reg:
+    ibv_dereg_mr(mhba_team->transpose_buf_mr);
+    free(mhba_team->transpose_buf);
 fail_after_shmat:
     if (-1 == shmdt(mhba_team->node.storage)) {
         xccl_mhba_error("failed to shmdt %p, errno %d", mhba_team->node.storage,
@@ -416,6 +486,7 @@ xccl_status_t xccl_mhba_team_destroy(xccl_tl_team_t *team)
     xccl_mhba_team_t *mhba_team = ucs_derived_of(team, xccl_mhba_team_t);
     int i;
     xccl_mhba_debug("destroying team %p", team);
+    ucs_rcache_destroy(mhba_team->context->rcache);
     if (-1 == shmdt(mhba_team->node.storage)) {
         xccl_mhba_error("failed to shmdt %p, errno %d", mhba_team->node.storage,
                         errno);
@@ -435,9 +506,12 @@ xccl_status_t xccl_mhba_team_destroy(xccl_tl_team_t *team)
             ibv_destroy_qp(mhba_team->net.qps[i]);
         }
         free(mhba_team->net.qps);
-        if (ibv_destroy_cq(mhba_team->net.cq)) {
-            xccl_mhba_error("net cq destroy failed (errno=%d)", errno);
+        for (i = 0; i < (mhba_team->transpose ? mhba_team->net.sbgp->group_size : 1); i++) {
+            if (ibv_destroy_cq(mhba_team->net.cqs[i])) {
+                xccl_mhba_error("net cq destroy failed (errno=%d)", errno);
+            }
         }
+        free(mhba_team->net.cqs);
 
         mhba_team->net.ucx_team->ctx->lib->team_destroy(
             mhba_team->net.ucx_team);
@@ -449,7 +523,7 @@ xccl_status_t xccl_mhba_team_destroy(xccl_tl_team_t *team)
     ibv_dereg_mr(mhba_team->dummy_bf_mr);
     free(mhba_team->work_completion);
     free(mhba_team->net.rank_map);
-    if (mhba_team->transpose_buf_mr) {
+    if (mhba_team->transpose) {
         ibv_dereg_mr(mhba_team->transpose_buf_mr);
         free(mhba_team->transpose_buf);
     }
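
Illustrative note (not part of the patch): the sketch below shows how the registration-cache path introduced by this change is meant to be exercised by a collective: look a buffer up in the team rcache, read the cached ibv_mr from the per-region payload, and drop the reference at finalize. The example_* helpers and the <ucs/memory/rcache.h> include path are assumptions; the remaining names follow the code added above.

/* Illustrative sketch only; helpers below are hypothetical. */
#include <sys/mman.h>            /* PROT_READ */
#include <ucs/memory/rcache.h>   /* assumed location of the UCX ucs_rcache_* API */
#include "xccl_mhba_lib.h"       /* xccl_mhba_reg_t, xccl_rcache_ucs_get_reg_data */

/* Get (or create) a cached registration for a buffer. On a cache miss the
 * rcache_reg_mr() hook added in xccl_mhba_team.c calls ibv_reg_mr(); on a hit
 * the existing MR is reused and only a region reference is taken. */
static xccl_status_t example_get_mr(xccl_mhba_team_t *team, void *buf,
                                    size_t len, int access_flags,
                                    xccl_mhba_reg_t **reg_p)
{
    ucs_rcache_region_t *region;

    if (UCS_OK != ucs_rcache_get(team->context->rcache, buf, len, PROT_READ,
                                 &access_flags, &region)) {
        return XCCL_ERR_NO_RESOURCE;
    }
    /* The per-region payload sits directly behind ucs_rcache_region_t, as
     * sized by rcache_params.region_struct_size in create_rcache(). */
    *reg_p = xccl_rcache_ucs_get_reg_data(region);
    return XCCL_OK;
}

/* Release the reference taken above; the MR itself stays cached until the
 * region is evicted or a UCM unmap event triggers rcache_dereg_mr(). */
static void example_put_mr(xccl_mhba_team_t *team, xccl_mhba_reg_t *reg)
{
    ucs_rcache_region_put(team->context->rcache, reg->region);
}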