Fix base reduce_scatter_block for large payloads
 * Update the four base reduce_scatter_block algorithms to support
   large-payload collectives.
   - Fixing the recursive doubling collective would have required
     changing some ompi_datatype functions, a more extensive change
     than I wanted to take on in this commit. So if a large payload
     is expected in that collective, it falls back to the linear
     algorithm (see the sketch just before the diff).

Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
jjhursey committed Feb 11, 2022
1 parent 34685a2 commit 61120c3
Showing 1 changed file with 70 additions and 33 deletions.
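
For context, here is a minimal standalone sketch of the issue the commit guards against. The sketch is not part of the commit, and reduce_chunk() is a hypothetical stand-in for a reduction call such as coll_reduce() whose count parameter is an int. Computing rcount * size in int arithmetic can overflow, so the fix computes the total element count as a size_t and, when it exceeds INT_MAX, performs the reduction in rcount-sized chunks; the real code additionally offsets each chunk with opal_datatype_span() rather than by raw element count.

#include <limits.h>
#include <stddef.h>

/* Hypothetical stand-in for a reduction routine whose count parameter
 * is an int, e.g. coll_reduce(). */
static int reduce_chunk(const double *in, double *inout, int count)
{
    for (int i = 0; i < count; ++i) {
        inout[i] += in[i];
    }
    return 0;
}

/* Reduce rcount * size elements even when the product exceeds INT_MAX. */
static int reduce_large(const double *in, double *inout, int rcount, int size)
{
    size_t total = (size_t)rcount * (size_t)size;   /* no int overflow */

    if (total > INT_MAX) {
        /* A single call cannot take a count above INT_MAX: loop over
         * rcount-element chunks, as the patched linear algorithm does. */
        for (int i = 0; i < size; ++i) {
            size_t off = (size_t)rcount * (size_t)i;
            int err = reduce_chunk(in + off, inout + off, rcount);
            if (0 != err) {
                return err;
            }
        }
        return 0;
    }
    return reduce_chunk(in, inout, (int)total);
}
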
103 changes: 70 additions & 33 deletions ompi/mca/coll/base/coll_base_reduce_scatter_block.c
@@ -17,6 +17,7 @@
* and Technology (RIST). All rights reserved.
* Copyright (c) 2018 Siberian State University of Telecommunications
* and Information Sciences. All rights reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
{
int rank, size, count, err = OMPI_SUCCESS;
int rank, size, err = OMPI_SUCCESS;
size_t count;
ptrdiff_t gap, span;
char *recv_buf = NULL, *recv_buf_free = NULL;

@@ -67,7 +69,7 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
size = ompi_comm_size(comm);

/* short cut the trivial case */
count = rcount * size;
count = rcount * (size_t)size;
if (0 == count) {
return OMPI_SUCCESS;
}
@@ -91,17 +93,38 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
recv_buf = recv_buf_free - gap;
}

/* reduction */
err =
comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
comm, comm->c_coll->coll_reduce_module);
if ( OPAL_UNLIKELY(count > INT_MAX) ) {
// Sending the message in the coll_reduce as "rcount*size" would exceed
// the 'int count' parameter in the coll_reduce() function. Instead reduce
// the result in "rcount" chunks.
int i;
void *rbuf_ptr, *sbuf_ptr;
span = opal_datatype_span(&dtype->super, rcount, &gap);
for( i = 0; i < size; ++i ) {
rbuf_ptr = (char*)recv_buf + span * (size_t)i;
sbuf_ptr = (char*)sbuf + span * (size_t)i;
/* reduction */
err =
comm->c_coll->coll_reduce(sbuf_ptr, rbuf_ptr, rcount, dtype, op, 0,
comm, comm->c_coll->coll_reduce_module);
if (MPI_SUCCESS != err) {
goto cleanup;
}
}
} else {
/* reduction */
err =
comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0,
comm, comm->c_coll->coll_reduce_module);
if (MPI_SUCCESS != err) {
goto cleanup;
}
}

/* scatter */
if (MPI_SUCCESS == err) {
err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
rbuf, rcount, dtype, 0,
comm, comm->c_coll->coll_scatter_module);
}
err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
rbuf, rcount, dtype, 0,
comm, comm->c_coll->coll_scatter_module);

cleanup:
if (NULL != recv_buf_free) free(recv_buf_free);
@@ -146,7 +169,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
if (comm_size < 2)
return MPI_SUCCESS;

totalcount = comm_size * rcount;
totalcount = comm_size * (size_t)rcount;
if( OPAL_UNLIKELY(totalcount > INT_MAX) ) {
/*
* Large payload collectives are not supported by this algorithm.
* The blocklens and displs calculations in the loop below
* will overflow an int data type.
* Fallback to the linear algorithm.
*/
return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module);
}
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf_raw = malloc(span);
@@ -347,7 +379,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype,
op, comm, module);
}
totalcount = comm_size * rcount;

totalcount = comm_size * (size_t)rcount;
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf_raw = malloc(span);
@@ -431,22 +464,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
* have their result calculated by the process to their
* right (rank + 1).
*/
int send_count = 0, recv_count = 0;
size_t send_count = 0, recv_count = 0;
if (vrank < vpeer) {
/* Send the right half of the buffer, recv the left half */
send_index = recv_index + mask;
send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
} else {
/* Send the left half of the buffer, recv the right half */
recv_index = send_index + mask;
send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
}
ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
2 * recv_index : nprocs_rem + recv_index);
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
2 * recv_index : nprocs_rem + recv_index);
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
struct ompi_request_t *request = NULL;

if (recv_count > 0) {
@@ -587,7 +620,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
sbuf, rbuf, rcount, dtype, op, comm, module);
}

totalcount = comm_size * rcount;
totalcount = comm_size * (size_t)rcount;
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf[0] = malloc(span);
@@ -677,13 +710,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
/* Send the upper half of reduction buffer, recv the lower half */
recv_index += nblocks;
}
int send_count = rcount * ompi_range_sum(send_index,
send_index + nblocks - 1, nprocs_rem - 1);
int recv_count = rcount * ompi_range_sum(recv_index,
recv_index + nblocks - 1, nprocs_rem - 1);
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
size_t send_count = rcount *
(size_t)ompi_range_sum(send_index,
send_index + nblocks - 1,
nprocs_rem - 1);
size_t recv_count = rcount *
(size_t)ompi_range_sum(recv_index,
recv_index + nblocks - 1,
nprocs_rem - 1);
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
2 * recv_index : nprocs_rem + recv_index);

err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
@@ -719,7 +756,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
* Process has two blocks: for excluded process and own.
* Send result to the excluded process.
*/
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent,
rcount, dtype, peer - 1,
@@ -729,7 +766,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
}

/* Send result to a remote process according to a mirror permutation */
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
/* If process has two blocks, then send the second block (own block) */
if (vpeer < nprocs_rem)
@@ -821,7 +858,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
if (rcount == 0 || comm_size < 2)
return MPI_SUCCESS;

totalcount = comm_size * rcount;
totalcount = comm_size * (size_t)rcount;
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf[0] = malloc(span);
@@ -843,7 +880,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
}

int nblocks = totalcount, send_index = 0, recv_index = 0;
size_t nblocks = totalcount, send_index = 0, recv_index = 0;
for (int mask = 1; mask < comm_size; mask <<= 1) {
int peer = rank ^ mask;
nblocks /= 2;
