Fix base reduce_scatter_block for large payloads #9956

Merged
168 changes: 125 additions & 43 deletions in ompi/mca/coll/base/coll_base_reduce_scatter_block.c
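This pull request targets calls where the per-rank block count multiplied by the communicator size no longer fits in a signed int. As a rough illustration only (the program below is not part of the change; the sizes are hypothetical, and with 4 ranks it needs a few GiB of memory per rank), a triggering call looks like this:

/* Illustrative sketch: a reduce_scatter_block whose total count
 * (rcount * comm_size) exceeds INT_MAX, the case this patch handles.
 * Error checks are omitted for brevity. */
#include <mpi.h>
#include <stdint.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    MPI_Init(&argc, &argv);

    int comm_size, rank;
    MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    /* Hypothetical block size: with 4 ranks the reduction spans
     * 4 * 600,000,000 = 2,400,000,000 elements, which overflows a 32-bit int. */
    const int rcount = 600000000;

    /* Each rank contributes one block of rcount elements per destination rank
     * and receives the reduced block that belongs to it. */
    uint8_t *sendbuf = calloc((size_t)rcount * (size_t)comm_size, sizeof(uint8_t));
    uint8_t *recvbuf = calloc((size_t)rcount, sizeof(uint8_t));

    MPI_Reduce_scatter_block(sendbuf, recvbuf, rcount, MPI_UINT8_T,
                             MPI_SUM, MPI_COMM_WORLD);

    free(sendbuf);
    free(recvbuf);
    MPI_Finalize();
    return 0;
}
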
@@ -17,6 +17,7 @@
* and Technology (RIST). All rights reserved.
* Copyright (c) 2018 Siberian State University of Telecommunications
* and Information Sciences. All rights reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
{
int rank, size, count, err = OMPI_SUCCESS;
int rank, size, err = OMPI_SUCCESS;
size_t count;
ptrdiff_t gap, span;
char *recv_buf = NULL, *recv_buf_free = NULL;

@@ -67,40 +69,106 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
size = ompi_comm_size(comm);

/* short cut the trivial case */
count = rcount * size;
count = rcount * (size_t)size;
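    /* count is size_t: rcount * size can exceed INT_MAX for large payloads. */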
if (0 == count) {
return OMPI_SUCCESS;
}

/* get datatype information */
span = opal_datatype_span(&dtype->super, count, &gap);

/* Handle MPI_IN_PLACE */
if (MPI_IN_PLACE == sbuf) {
sbuf = rbuf;
}

if (0 == rank) {
/* temporary receive buffer. See coll_basic_reduce.c for
details on sizing */
recv_buf_free = (char*) malloc(span);
if (NULL == recv_buf_free) {
err = OMPI_ERR_OUT_OF_RESOURCE;
goto cleanup;
    /*
     * For large payloads (defined as a total count greater than INT_MAX),
     * we segment the reduction per rank to reduce the memory footprint on
     * the root, then send each reduced block to its destination rank.
     *
     * Additionally, passing "rcount*size" as the count to coll_reduce()
     * would overflow its 'int count' parameter, so a different technique
     * is required when the total count exceeds INT_MAX.
     */
if ( OPAL_UNLIKELY(count > INT_MAX) ) {
int i;
void *sbuf_ptr;

/* Get datatype information for an individual block */
span = opal_datatype_span(&dtype->super, rcount, &gap);

if (0 == rank) {
/* temporary receive buffer. See coll_basic_reduce.c for
details on sizing */
recv_buf_free = (char*) malloc(span);
if (NULL == recv_buf_free) {
err = OMPI_ERR_OUT_OF_RESOURCE;
goto cleanup;
}
recv_buf = recv_buf_free - gap;
}
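        /* In this branch the root's scratch buffer spans a single block of
         * rcount elements; the non-overflow path below sizes it for the
         * full rcount * size reduction instead. */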

for( i = 0; i < size; ++i ) {
/* Calculate the portion of the send buffer to reduce over */
sbuf_ptr = (char*)sbuf + span * (size_t)i;

/* Reduction for this peer */
err = comm->c_coll->coll_reduce(sbuf_ptr, recv_buf, rcount,
dtype, op, 0, comm,
comm->c_coll->coll_reduce_module);
if (MPI_SUCCESS != err) {
goto cleanup;
}

/* Send reduce results to this peer */
if (0 == rank ) {
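                /* The root keeps its own block via a local copy; every other
                 * block is sent to its destination rank. */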
if( i == rank ) {
err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf, recv_buf);
} else {
err = MCA_PML_CALL(send(recv_buf, rcount, dtype, i,
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
MCA_PML_BASE_SEND_STANDARD, comm));
}
if (MPI_SUCCESS != err) {
goto cleanup;
}
}
else if( i == rank ) {
err = MCA_PML_CALL(recv(rbuf, rcount, dtype, 0,
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
comm, MPI_STATUS_IGNORE));
if (MPI_SUCCESS != err) {
goto cleanup;
}
}
}
recv_buf = recv_buf_free - gap;
}
else {
/* get datatype information */
span = opal_datatype_span(&dtype->super, count, &gap);

if (0 == rank) {
/* temporary receive buffer. See coll_basic_reduce.c for
details on sizing */
recv_buf_free = (char*) malloc(span);
if (NULL == recv_buf_free) {
err = OMPI_ERR_OUT_OF_RESOURCE;
goto cleanup;
}
recv_buf = recv_buf_free - gap;
}

/* reduction */
err =
comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
comm, comm->c_coll->coll_reduce_module);
/* reduction */
err =
comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0,
comm, comm->c_coll->coll_reduce_module);
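        /* The (int) cast is safe here: this branch only runs when count <= INT_MAX. */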
if (MPI_SUCCESS != err) {
goto cleanup;
}

/* scatter */
if (MPI_SUCCESS == err) {
/* scatter */
err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
rbuf, rcount, dtype, 0,
comm, comm->c_coll->coll_scatter_module);
rbuf, rcount, dtype, 0,
comm, comm->c_coll->coll_scatter_module);
}

cleanup:
@@ -146,7 +214,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
if (comm_size < 2)
return MPI_SUCCESS;

totalcount = comm_size * rcount;
totalcount = comm_size * (size_t)rcount;
if( OPAL_UNLIKELY(totalcount > INT_MAX) ) {
/*
         * Large payloads are not supported by this algorithm:
         * the blocklens and displs calculations in the loop below
         * would overflow an int data type.
         * Fall back to the linear algorithm.
*/
return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module);
}
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf_raw = malloc(span);
@@ -347,7 +424,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype,
op, comm, module);
}
totalcount = comm_size * rcount;

totalcount = comm_size * (size_t)rcount;
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf_raw = malloc(span);
@@ -431,22 +509,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
* have their result calculated by the process to their
* right (rank + 1).
*/
int send_count = 0, recv_count = 0;
size_t send_count = 0, recv_count = 0;
if (vrank < vpeer) {
/* Send the right half of the buffer, recv the left half */
send_index = recv_index + mask;
send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
} else {
/* Send the left half of the buffer, recv the right half */
recv_index = send_index + mask;
send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
}
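            /* Block displacements in the temporary buffer: a virtual rank below
             * nprocs_rem represents two real ranks (2*index and 2*index+1) and
             * owns two consecutive blocks starting at block 2*index; the
             * remaining virtual ranks map to real rank index + nprocs_rem. */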
ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
2 * recv_index : nprocs_rem + recv_index);
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
2 * recv_index : nprocs_rem + recv_index);
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
struct ompi_request_t *request = NULL;

if (recv_count > 0) {
@@ -587,7 +665,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
sbuf, rbuf, rcount, dtype, op, comm, module);
}

totalcount = comm_size * rcount;
totalcount = comm_size * (size_t)rcount;
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf[0] = malloc(span);
@@ -677,13 +755,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
/* Send the upper half of reduction buffer, recv the lower half */
recv_index += nblocks;
}
int send_count = rcount * ompi_range_sum(send_index,
send_index + nblocks - 1, nprocs_rem - 1);
int recv_count = rcount * ompi_range_sum(recv_index,
recv_index + nblocks - 1, nprocs_rem - 1);
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
size_t send_count = rcount *
(size_t)ompi_range_sum(send_index,
send_index + nblocks - 1,
nprocs_rem - 1);
size_t recv_count = rcount *
(size_t)ompi_range_sum(recv_index,
recv_index + nblocks - 1,
nprocs_rem - 1);
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
2 * recv_index : nprocs_rem + recv_index);

err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
@@ -719,7 +801,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
* Process has two blocks: for excluded process and own.
* Send result to the excluded process.
*/
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent,
rcount, dtype, peer - 1,
@@ -729,7 +811,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
}

/* Send result to a remote process according to a mirror permutation */
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
2 * send_index : nprocs_rem + send_index);
/* If process has two blocks, then send the second block (own block) */
if (vpeer < nprocs_rem)
@@ -821,7 +903,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
if (rcount == 0 || comm_size < 2)
return MPI_SUCCESS;

totalcount = comm_size * rcount;
totalcount = comm_size * (size_t)rcount;
ompi_datatype_type_extent(dtype, &extent);
span = opal_datatype_span(&dtype->super, totalcount, &gap);
tmpbuf[0] = malloc(span);
@@ -843,7 +925,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
}

int nblocks = totalcount, send_index = 0, recv_index = 0;
size_t nblocks = totalcount, send_index = 0, recv_index = 0;
for (int mask = 1; mask < comm_size; mask <<= 1) {
int peer = rank ^ mask;
nblocks /= 2;