From 9aa0b8d55bc9c2e8d8fb34eccb491007f41986a4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Oct 2023 12:41:44 -0500 Subject: [PATCH] Mark mpi functions of booleans as inactive (#1464) --- enzyme/Enzyme/Enzyme.cpp | 37 +++++++++++++++++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 13 ++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 275951c1d638..7c14c5caf2d4 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -239,6 +239,43 @@ void attributeKnownFunctions(llvm::Function &F) { F.addParamAttr(2, Attribute::WriteOnly); F.addParamAttr(2, Attribute::NoCapture); } + // Map of MPI function name to the arg index of its type argument + std::map MPI_TYPE_ARGS = { + {"MPI_Send", 2}, {"MPI_Ssend", 2}, {"MPI_Bsend", 2}, + {"MPI_Recv", 2}, {"MPI_Brecv", 2}, {"PMPI_Send", 2}, + {"PMPI_Ssend", 2}, {"PMPI_Bsend", 2}, {"PMPI_Recv", 2}, + {"PMPI_Brecv", 2}, + + {"MPI_Isend", 2}, {"MPI_Irecv", 2}, {"PMPI_Isend", 2}, + {"PMPI_Irecv", 2}, + + {"MPI_Reduce", 3}, {"PMPI_Reduce", 3}, + + {"MPI_Allreduce", 3}, {"PMPI_Allreduce", 3}}; + { + auto found = MPI_TYPE_ARGS.find(F.getName().str()); + if (found != MPI_TYPE_ARGS.end()) { + for (auto user : F.users()) { + if (auto CI = dyn_cast(user)) + if (CI->getCalledFunction() == &F) { + if (Constant *C = + dyn_cast(CI->getArgOperand(found->second))) { + while (ConstantExpr *CE = dyn_cast(C)) { + C = CE->getOperand(0); + } + if (auto GV = dyn_cast(C)) { + if (GV->getName() == "ompi_mpi_cxx_bool") { + CI->addAttribute( + AttributeList::FunctionIndex, + Attribute::get(CI->getContext(), "enzyme_inactive")); + } + } + } + } + } + } + } + if (F.getName() == "omp_get_max_threads" || F.getName() == "omp_get_thread_num") { #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index d7b1132f2f2b..1276aa310769 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4031,6 +4031,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast(C)) { // MPICH @@ -4051,7 +4053,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } - if (funcName == "MPI_Isend" || funcName == "MPI_Irecv") { + if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" || + funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") { TypeTree buf = TypeTree(BaseType::Pointer); if (Constant *C = dyn_cast(call.getOperand(2))) { @@ -4063,6 +4066,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast(C)) { // MPICH @@ -4137,6 +4142,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast(C)) { // MPICH @@ -4164,7 +4171,7 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } - if (funcName == "MPI_Allreduce") { + if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") { TypeTree buf = TypeTree(BaseType::Pointer); if (Constant *C = dyn_cast(call.getOperand(3))) { @@ -4176,6 +4183,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast(C)) { // MPICH