From 61120c320a2ffa65a39d0fc898aa3b54c79118dd Mon Sep 17 00:00:00 2001 From: Joshua Hursey Date: Tue, 1 Feb 2022 14:44:19 -0600 Subject: [PATCH] Fix base reduce_scatter_block for large payloads * Update the four base reduce_scatter_block algorithms to support large payload collectives. - The recursive doubling collective fix would have required changing some ompi_datatype functions which was more extensive than I wanted to go after in this commit. So if a large payload is expected in that collective then it falls back to the linear algorithm. Signed-off-by: Joshua Hursey --- .../base/coll_base_reduce_scatter_block.c | 103 ++++++++++++------ 1 file changed, 70 insertions(+), 33 deletions(-) diff --git a/ompi/mca/coll/base/coll_base_reduce_scatter_block.c b/ompi/mca/coll/base/coll_base_reduce_scatter_block.c index 6dd83daad4f..24b39d97676 100644 --- a/ompi/mca/coll/base/coll_base_reduce_scatter_block.c +++ b/ompi/mca/coll/base/coll_base_reduce_scatter_block.c @@ -17,6 +17,7 @@ * and Technology (RIST). All rights reserved. * Copyright (c) 2018 Siberian State University of Telecommunications * and Information Sciences. All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - int rank, size, count, err = OMPI_SUCCESS; + int rank, size, err = OMPI_SUCCESS; + size_t count; ptrdiff_t gap, span; char *recv_buf = NULL, *recv_buf_free = NULL; @@ -67,7 +69,7 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i size = ompi_comm_size(comm); /* short cut the trivial case */ - count = rcount * size; + count = rcount * (size_t)size; if (0 == count) { return OMPI_SUCCESS; } @@ -91,17 +93,38 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i recv_buf = recv_buf_free - gap; } - /* reduction */ - err = - comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0, - comm, comm->c_coll->coll_reduce_module); + if ( OPAL_UNLIKELY(count > INT_MAX) ) { + // Sending the message in the coll_reduce as "rcount*size" would exceed + // the 'int count' parameter in the coll_reduce() function. Instead reduce + // the result in "rcount" chunks. + int i; + void *rbuf_ptr, *sbuf_ptr; + span = opal_datatype_span(&dtype->super, rcount, &gap); + for( i = 0; i < size; ++i ) { + rbuf_ptr = (char*)recv_buf + span * (size_t)i; + sbuf_ptr = (char*)sbuf + span * (size_t)i; + /* reduction */ + err = + comm->c_coll->coll_reduce(sbuf_ptr, rbuf_ptr, rcount, dtype, op, 0, + comm, comm->c_coll->coll_reduce_module); + if (MPI_SUCCESS != err) { + goto cleanup; + } + } + } else { + /* reduction */ + err = + comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0, + comm, comm->c_coll->coll_reduce_module); + if (MPI_SUCCESS != err) { + goto cleanup; + } + } /* scatter */ - if (MPI_SUCCESS == err) { - err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype, - rbuf, rcount, dtype, 0, - comm, comm->c_coll->coll_scatter_module); - } + err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype, + rbuf, rcount, dtype, 0, + comm, comm->c_coll->coll_scatter_module); cleanup: if (NULL != recv_buf_free) free(recv_buf_free); @@ -146,7 +169,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling( if (comm_size < 2) return MPI_SUCCESS; - totalcount = comm_size * rcount; + totalcount = comm_size * (size_t)rcount; + if( OPAL_UNLIKELY(totalcount > INT_MAX) ) { + /* + * Large payload collectives are not supported by this algorithm. + * The blocklens and displs calculations in the loop below + * will overflow an int data type. + * Fallback to the linear algorithm. + */ + return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module); + } ompi_datatype_type_extent(dtype, &extent); span = opal_datatype_span(&dtype->super, totalcount, &gap); tmpbuf_raw = malloc(span); @@ -347,7 +379,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving( return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module); } - totalcount = comm_size * rcount; + + totalcount = comm_size * (size_t)rcount; ompi_datatype_type_extent(dtype, &extent); span = opal_datatype_span(&dtype->super, totalcount, &gap); tmpbuf_raw = malloc(span); @@ -431,22 +464,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving( * have their result calculated by the process to their * right (rank + 1). */ - int send_count = 0, recv_count = 0; + size_t send_count = 0, recv_count = 0; if (vrank < vpeer) { /* Send the right half of the buffer, recv the left half */ send_index = recv_index + mask; - send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1); - recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1); + send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1); + recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1); } else { /* Send the left half of the buffer, recv the right half */ recv_index = send_index + mask; - send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1); - recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1); + send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1); + recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1); } - ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ? - 2 * recv_index : nprocs_rem + recv_index); - ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ? - 2 * send_index : nprocs_rem + send_index); + ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ? + 2 * recv_index : nprocs_rem + recv_index); + ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ? + 2 * send_index : nprocs_rem + send_index); struct ompi_request_t *request = NULL; if (recv_count > 0) { @@ -587,7 +620,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly( sbuf, rbuf, rcount, dtype, op, comm, module); } - totalcount = comm_size * rcount; + totalcount = comm_size * (size_t)rcount; ompi_datatype_type_extent(dtype, &extent); span = opal_datatype_span(&dtype->super, totalcount, &gap); tmpbuf[0] = malloc(span); @@ -677,13 +710,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly( /* Send the upper half of reduction buffer, recv the lower half */ recv_index += nblocks; } - int send_count = rcount * ompi_range_sum(send_index, - send_index + nblocks - 1, nprocs_rem - 1); - int recv_count = rcount * ompi_range_sum(recv_index, - recv_index + nblocks - 1, nprocs_rem - 1); - ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ? + size_t send_count = rcount * + (size_t)ompi_range_sum(send_index, + send_index + nblocks - 1, + nprocs_rem - 1); + size_t recv_count = rcount * + (size_t)ompi_range_sum(recv_index, + recv_index + nblocks - 1, + nprocs_rem - 1); + ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ? 2 * send_index : nprocs_rem + send_index); - ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ? + ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ? 2 * recv_index : nprocs_rem + recv_index); err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count, @@ -719,7 +756,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly( * Process has two blocks: for excluded process and own. * Send result to the excluded process. */ - ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ? + ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ? 2 * send_index : nprocs_rem + send_index); err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent, rcount, dtype, peer - 1, @@ -729,7 +766,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly( } /* Send result to a remote process according to a mirror permutation */ - ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ? + ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ? 2 * send_index : nprocs_rem + send_index); /* If process has two blocks, then send the second block (own block) */ if (vpeer < nprocs_rem) @@ -821,7 +858,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2( if (rcount == 0 || comm_size < 2) return MPI_SUCCESS; - totalcount = comm_size * rcount; + totalcount = comm_size * (size_t)rcount; ompi_datatype_type_extent(dtype, &extent); span = opal_datatype_span(&dtype->super, totalcount, &gap); tmpbuf[0] = malloc(span); @@ -843,7 +880,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2( if (MPI_SUCCESS != err) { goto cleanup_and_return; } } - int nblocks = totalcount, send_index = 0, recv_index = 0; + size_t nblocks = totalcount, send_index = 0, recv_index = 0; for (int mask = 1; mask < comm_size; mask <<= 1) { int peer = rank ^ mask; nblocks /= 2;