From 89b219db94e6e6a32121f91695c8d5f11d621e9d Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Wed, 5 Aug 2020 14:31:29 -0400 Subject: [PATCH] added 64-bit MPI gather into SSC helper --- source/adios2/engine/ssc/SscHelper.cpp | 180 +++++++++++++++++++++++++ source/adios2/engine/ssc/SscHelper.h | 17 +++ 2 files changed, 197 insertions(+) diff --git a/source/adios2/engine/ssc/SscHelper.cpp b/source/adios2/engine/ssc/SscHelper.cpp index 47f1e7e6e5..7026ffec99 100644 --- a/source/adios2/engine/ssc/SscHelper.cpp +++ b/source/adios2/engine/ssc/SscHelper.cpp @@ -281,6 +281,186 @@ bool AreSameDims(const Dims &a, const Dims &b) return true; } +void MPI_Gatherv64(const void *sendbuf, uint64_t sendcount, + MPI_Datatype sendtype, void *recvbuf, + const uint64_t *recvcounts, const uint64_t *displs, + MPI_Datatype recvtype, int root, MPI_Comm comm, + const int chunksize) +{ + + int mpiSize; + int mpiRank; + MPI_Comm_size(comm, &mpiSize); + MPI_Comm_rank(comm, &mpiRank); + + int recvTypeSize; + int sendTypeSize; + + MPI_Type_size(recvtype, &recvTypeSize); + MPI_Type_size(sendtype, &sendTypeSize); + + std::vector requests; + if (mpiRank == root) + { + for (int i = 0; i < mpiSize; ++i) + { + uint64_t recvcount = recvcounts[i]; + while (recvcount > 0) + { + requests.emplace_back(); + if (recvcount > chunksize) + { + MPI_Irecv(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + chunksize, recvtype, i, 0, comm, + &requests.back()); + recvcount -= chunksize; + } + else + { + MPI_Irecv(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + static_cast(recvcount), recvtype, i, 0, comm, + &requests.back()); + recvcount = 0; + } + } + } + } + + uint64_t sendcountvar = sendcount; + + while (sendcountvar > 0) + { + requests.emplace_back(); + if (sendcountvar > chunksize) + { + MPI_Isend(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + chunksize, sendtype, root, 0, comm, &requests.back()); + sendcountvar -= chunksize; + } + else + { + MPI_Isend(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + static_cast(sendcountvar), sendtype, root, 0, comm, + &requests.back()); + sendcountvar = 0; + } + } + + MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE); +} + +void MPI_Gatherv64OneSidedPull(const void *sendbuf, uint64_t sendcount, + MPI_Datatype sendtype, void *recvbuf, + const uint64_t *recvcounts, + const uint64_t *displs, MPI_Datatype recvtype, + int root, MPI_Comm comm, const int chunksize) +{ + + int mpiSize; + int mpiRank; + MPI_Comm_size(comm, &mpiSize); + MPI_Comm_rank(comm, &mpiRank); + + int recvTypeSize; + int sendTypeSize; + + MPI_Type_size(recvtype, &recvTypeSize); + MPI_Type_size(sendtype, &sendTypeSize); + + MPI_Win win; + MPI_Win_create(const_cast(sendbuf), sendcount * sendTypeSize, + sendTypeSize, MPI_INFO_NULL, comm, &win); + + if (mpiRank == root) + { + for (int i = 0; i < mpiSize; ++i) + { + uint64_t recvcount = recvcounts[i]; + while (recvcount > 0) + { + if (recvcount > chunksize) + { + MPI_Get(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + chunksize, recvtype, i, recvcounts[i] - recvcount, + chunksize, recvtype, win); + recvcount -= chunksize; + } + else + { + MPI_Get(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + static_cast(recvcount), recvtype, i, + recvcounts[i] - recvcount, + static_cast(recvcount), recvtype, win); + recvcount = 0; + } + } + } + } + + MPI_Win_free(&win); +} + +void MPI_Gatherv64OneSidedPush(const void *sendbuf, uint64_t sendcount, + MPI_Datatype sendtype, void *recvbuf, + const uint64_t *recvcounts, + const uint64_t *displs, MPI_Datatype recvtype, + int root, MPI_Comm comm, const int chunksize) +{ + + int mpiSize; + int mpiRank; + MPI_Comm_size(comm, &mpiSize); + MPI_Comm_rank(comm, &mpiRank); + + int recvTypeSize; + int sendTypeSize; + + MPI_Type_size(recvtype, &recvTypeSize); + MPI_Type_size(sendtype, &sendTypeSize); + + uint64_t recvsize = displs[mpiSize - 1] + recvcounts[mpiSize - 1]; + + MPI_Win win; + MPI_Win_create(recvbuf, recvsize * recvTypeSize, recvTypeSize, + MPI_INFO_NULL, comm, &win); + + uint64_t sendcountvar = sendcount; + + while (sendcountvar > 0) + { + if (sendcountvar > chunksize) + { + MPI_Put(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + chunksize, sendtype, root, + displs[mpiRank] + sendcount - sendcountvar, chunksize, + sendtype, win); + sendcountvar -= chunksize; + } + else + { + MPI_Put(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + sendcountvar, sendtype, root, + displs[mpiRank] + sendcount - sendcountvar, sendcountvar, + sendtype, win); + sendcountvar = 0; + } + } + + MPI_Win_free(&win); +} + void PrintDims(const Dims &dims, const std::string &label) { std::cout << label; diff --git a/source/adios2/engine/ssc/SscHelper.h b/source/adios2/engine/ssc/SscHelper.h index 82139e3a06..da75a0004b 100644 --- a/source/adios2/engine/ssc/SscHelper.h +++ b/source/adios2/engine/ssc/SscHelper.h @@ -15,6 +15,7 @@ #include "adios2/core/IO.h" #include "nlohmann/json.hpp" #include +#include #include namespace adios2 @@ -70,6 +71,22 @@ void JsonToBlockVecVec(const nlohmann::json &input, BlockVecVec &output); void JsonToBlockVecVec(const std::vector &input, BlockVecVec &output); void JsonToBlockVecVec(const std::string &input, BlockVecVec &output); +void MPI_Gatherv64OneSidedPush( + const void *sendbuf, uint64_t sendcount, MPI_Datatype sendtype, + void *recvbuf, const uint64_t *recvcounts, const uint64_t *displs, + MPI_Datatype recvtype, int root, MPI_Comm comm, + const int chunksize = std::numeric_limits::max()); +void MPI_Gatherv64OneSidedPull( + const void *sendbuf, uint64_t sendcount, MPI_Datatype sendtype, + void *recvbuf, const uint64_t *recvcounts, const uint64_t *displs, + MPI_Datatype recvtype, int root, MPI_Comm comm, + const int chunksize = std::numeric_limits::max()); +void MPI_Gatherv64(const void *sendbuf, uint64_t sendcount, + MPI_Datatype sendtype, void *recvbuf, + const uint64_t *recvcounts, const uint64_t *displs, + MPI_Datatype recvtype, int root, MPI_Comm comm, + const int chunksize = std::numeric_limits::max()); + bool AreSameDims(const Dims &a, const Dims &b); } // end namespace ssc