From fe07940cfd5507871ce2a747a6c88149cc8096af Mon Sep 17 00:00:00 2001 From: Joshua Hursey Date: Thu, 27 Jan 2022 16:41:46 -0500 Subject: [PATCH] Fix intercommunicator overflow with big payload collectives * The 'inter' collective component was multiplying the `int` count by the `int` size of the communicator which can overflow the integer. - Solution is to preserve the full `size_t` value in the compuation which the PML supports. * `allgather`, `gather`, `scatter` all overflowed in a multiply - Preserve the full `size_t` value in the multiply - allgather needed extra code to handle the bcast of the result * `allgatherv`, `gatherv`, `scatterv` all overflowed a `total` variable that accumulated over the count array. - Preserve the full `size_t` value in `total` type Signed-off-by: Joshua Hursey --- ompi/mca/coll/inter/coll_inter_allgather.c | 36 +++++++++++++++------ ompi/mca/coll/inter/coll_inter_allgatherv.c | 4 ++- ompi/mca/coll/inter/coll_inter_gather.c | 5 +-- ompi/mca/coll/inter/coll_inter_gatherv.c | 4 ++- ompi/mca/coll/inter/coll_inter_scatter.c | 5 +-- ompi/mca/coll/inter/coll_inter_scatterv.c | 4 ++- 6 files changed, 42 insertions(+), 16 deletions(-) diff --git a/ompi/mca/coll/inter/coll_inter_allgather.c b/ompi/mca/coll/inter/coll_inter_allgather.c index 6bd0e91b58d..fe867cda06a 100644 --- a/ompi/mca/coll/inter/coll_inter_allgather.c +++ b/ompi/mca/coll/inter/coll_inter_allgather.c @@ -12,6 +12,7 @@ * Copyright (c) 2006-2010 University of Houston. All rights reserved. * Copyright (c) 2015-2017 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -48,9 +49,10 @@ mca_coll_inter_allgather_inter(const void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - int rank, root = 0, size, rsize, err = OMPI_SUCCESS; + int rank, root = 0, size, rsize, err = OMPI_SUCCESS, i; char *ptmp_free = NULL, *ptmp = NULL; ptrdiff_t gap, span; + void *rbuf_ptr; rank = ompi_comm_rank(comm); size = ompi_comm_size(comm->c_local_comm); @@ -76,9 +78,9 @@ mca_coll_inter_allgather_inter(const void *sbuf, int scount, if (rank == root) { /* Do a send-recv between the two root procs. to avoid deadlock */ - err = ompi_coll_base_sendrecv_actual(ptmp, scount*size, sdtype, 0, + err = ompi_coll_base_sendrecv_actual(ptmp, scount*(size_t)size, sdtype, 0, MCA_COLL_BASE_TAG_ALLGATHER, - rbuf, rcount*rsize, rdtype, 0, + rbuf, rcount*(size_t)rsize, rdtype, 0, MCA_COLL_BASE_TAG_ALLGATHER, comm, MPI_STATUS_IGNORE); if (OMPI_SUCCESS != err) { @@ -87,12 +89,28 @@ mca_coll_inter_allgather_inter(const void *sbuf, int scount, } /* bcast the message to all the local processes */ if ( rcount > 0 ) { - err = comm->c_local_comm->c_coll->coll_bcast(rbuf, rcount*rsize, rdtype, - root, comm->c_local_comm, - comm->c_local_comm->c_coll->coll_bcast_module); - if (OMPI_SUCCESS != err) { - goto exit; - } + if ( OPAL_UNLIKELY(rcount*(size_t)rsize > INT_MAX) ) { + // Sending the message in the coll_bcast as "rcount*rsize" would exceed + // the 'int count' parameter in the coll_bcast() function. Instead broadcast + // the result in "rcount" chunks to the local group. + span = opal_datatype_span(&rdtype->super, rcount, &gap); + for( i = 0; i < rsize; ++i) { + rbuf_ptr = (char*)rbuf + span * (size_t)i; + err = comm->c_local_comm->c_coll->coll_bcast(rbuf_ptr, rcount, rdtype, + root, comm->c_local_comm, + comm->c_local_comm->c_coll->coll_bcast_module); + if (OMPI_SUCCESS != err) { + goto exit; + } + } + } else { + err = comm->c_local_comm->c_coll->coll_bcast(rbuf, rcount*rsize, rdtype, + root, comm->c_local_comm, + comm->c_local_comm->c_coll->coll_bcast_module); + if (OMPI_SUCCESS != err) { + goto exit; + } + } } exit: diff --git a/ompi/mca/coll/inter/coll_inter_allgatherv.c b/ompi/mca/coll/inter/coll_inter_allgatherv.c index 0728fd28072..7a35e25a9c6 100644 --- a/ompi/mca/coll/inter/coll_inter_allgatherv.c +++ b/ompi/mca/coll/inter/coll_inter_allgatherv.c @@ -12,6 +12,7 @@ * Copyright (c) 2006-2010 University of Houston. All rights reserved. * Copyright (c) 2015-2017 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -47,7 +48,8 @@ mca_coll_inter_allgatherv_inter(const void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - int i, rank, size, size_local, total=0, err; + int i, rank, size, size_local, err; + size_t total = 0; int *count=NULL,*displace=NULL; char *ptmp_free=NULL, *ptmp=NULL; ompi_datatype_t *ndtype = NULL; diff --git a/ompi/mca/coll/inter/coll_inter_gather.c b/ompi/mca/coll/inter/coll_inter_gather.c index f1a7356224d..05ffc736efb 100644 --- a/ompi/mca/coll/inter/coll_inter_gather.c +++ b/ompi/mca/coll/inter/coll_inter_gather.c @@ -12,6 +12,7 @@ * Copyright (c) 2006-2007 University of Houston. All rights reserved. * Copyright (c) 2015-2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -76,7 +77,7 @@ mca_coll_inter_gather_inter(const void *sbuf, int scount, comm->c_local_comm->c_coll->coll_gather_module); if (0 == rank) { /* First process sends data to the root */ - err = MCA_PML_CALL(send(ptmp, scount*size_local, sdtype, root, + err = MCA_PML_CALL(send(ptmp, scount*(size_t)size_local, sdtype, root, MCA_COLL_BASE_TAG_GATHER, MCA_PML_BASE_SEND_STANDARD, comm)); if (OMPI_SUCCESS != err) { @@ -86,7 +87,7 @@ mca_coll_inter_gather_inter(const void *sbuf, int scount, free(ptmp_free); } else { /* I am the root, loop receiving the data. */ - err = MCA_PML_CALL(recv(rbuf, rcount*size, rdtype, 0, + err = MCA_PML_CALL(recv(rbuf, rcount*(size_t)size, rdtype, 0, MCA_COLL_BASE_TAG_GATHER, comm, MPI_STATUS_IGNORE)); if (OMPI_SUCCESS != err) { diff --git a/ompi/mca/coll/inter/coll_inter_gatherv.c b/ompi/mca/coll/inter/coll_inter_gatherv.c index 5dd9f7b4b68..3ee00890348 100644 --- a/ompi/mca/coll/inter/coll_inter_gatherv.c +++ b/ompi/mca/coll/inter/coll_inter_gatherv.c @@ -12,6 +12,7 @@ * Copyright (c) 2006-2010 University of Houston. All rights reserved. * Copyright (c) 2015-2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -44,7 +45,8 @@ mca_coll_inter_gatherv_inter(const void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - int i, rank, size, size_local, total=0, err; + int i, rank, size, size_local, err; + size_t total = 0; int *count=NULL, *displace=NULL; char *ptmp_free=NULL, *ptmp=NULL; ompi_datatype_t *ndtype; diff --git a/ompi/mca/coll/inter/coll_inter_scatter.c b/ompi/mca/coll/inter/coll_inter_scatter.c index 94871f0be93..42ab948c738 100644 --- a/ompi/mca/coll/inter/coll_inter_scatter.c +++ b/ompi/mca/coll/inter/coll_inter_scatter.c @@ -12,6 +12,7 @@ * Copyright (c) 2006-2008 University of Houston. All rights reserved. * Copyright (c) 2015-2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -69,7 +70,7 @@ mca_coll_inter_scatter_inter(const void *sbuf, int scount, } ptmp = ptmp_free - gap; - err = MCA_PML_CALL(recv(ptmp, rcount*size_local, rdtype, + err = MCA_PML_CALL(recv(ptmp, rcount*(size_t)size_local, rdtype, root, MCA_COLL_BASE_TAG_SCATTER, comm, MPI_STATUS_IGNORE)); if (OMPI_SUCCESS != err) { @@ -86,7 +87,7 @@ mca_coll_inter_scatter_inter(const void *sbuf, int scount, } } else { /* Root sends data to the first process in the remote group */ - err = MCA_PML_CALL(send(sbuf, scount*size, sdtype, 0, + err = MCA_PML_CALL(send(sbuf, scount*(size_t)size, sdtype, 0, MCA_COLL_BASE_TAG_SCATTER, MCA_PML_BASE_SEND_STANDARD, comm)); if (OMPI_SUCCESS != err) { diff --git a/ompi/mca/coll/inter/coll_inter_scatterv.c b/ompi/mca/coll/inter/coll_inter_scatterv.c index e0ccaedd1e6..0d0246af5be 100644 --- a/ompi/mca/coll/inter/coll_inter_scatterv.c +++ b/ompi/mca/coll/inter/coll_inter_scatterv.c @@ -12,6 +12,7 @@ * Copyright (c) 2006-2010 University of Houston. All rights reserved. * Copyright (c) 2015-2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -45,7 +46,8 @@ mca_coll_inter_scatterv_inter(const void *sbuf, const int *scounts, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - int i, rank, size, err, total=0, size_local; + int i, rank, size, err, size_local; + size_t total = 0; int *counts=NULL,*displace=NULL; char *ptmp_free=NULL, *ptmp=NULL; ompi_datatype_t *ndtype;