From ce748dacd8800149dddb97d095f5d41b67d0bef3 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 8 Aug 2024 16:42:46 +0200 Subject: [PATCH] [LLVM][NewPM] Add a C API for setting the PassBuilder AA pipeline. --- llvm/include/llvm-c/Transforms/PassBuilder.h | 8 ++++++++ llvm/lib/Passes/PassBuilderBindings.cpp | 19 ++++++++++++++++++- .../PassBuilderBindingsTest.cpp | 1 + 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/llvm/include/llvm-c/Transforms/PassBuilder.h b/llvm/include/llvm-c/Transforms/PassBuilder.h index d0466dd7fc0a12..03a5abaa753312 100644 --- a/llvm/include/llvm-c/Transforms/PassBuilder.h +++ b/llvm/include/llvm-c/Transforms/PassBuilder.h @@ -72,6 +72,14 @@ void LLVMPassBuilderOptionsSetVerifyEach(LLVMPassBuilderOptionsRef Options, void LLVMPassBuilderOptionsSetDebugLogging(LLVMPassBuilderOptionsRef Options, LLVMBool DebugLogging); +/** + * Specify a custom alias analysis pipeline for the PassBuilder to be used + * instead of the default one. The string argument is not copied; the caller + * is responsible for ensuring it outlives the PassBuilderOptions instance. + */ +void LLVMPassBuilderOptionsSetAAPipeline(LLVMPassBuilderOptionsRef Options, + const char *AAPipeline); + void LLVMPassBuilderOptionsSetLoopInterleaving( LLVMPassBuilderOptionsRef Options, LLVMBool LoopInterleaving); diff --git a/llvm/lib/Passes/PassBuilderBindings.cpp b/llvm/lib/Passes/PassBuilderBindings.cpp index b80dc0231ed5fd..4e12dd2226c4d2 100644 --- a/llvm/lib/Passes/PassBuilderBindings.cpp +++ b/llvm/lib/Passes/PassBuilderBindings.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm-c/Transforms/PassBuilder.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Passes/PassBuilder.h" @@ -28,11 +29,14 @@ class LLVMPassBuilderOptions { public: explicit LLVMPassBuilderOptions( bool DebugLogging = false, bool VerifyEach = false, + const char *AAPipeline = nullptr, PipelineTuningOptions PTO = PipelineTuningOptions()) - : DebugLogging(DebugLogging), VerifyEach(VerifyEach), PTO(PTO) {} + : DebugLogging(DebugLogging), VerifyEach(VerifyEach), + AAPipeline(AAPipeline), PTO(PTO) {} bool DebugLogging; bool VerifyEach; + const char *AAPipeline; PipelineTuningOptions PTO; }; } // namespace llvm @@ -60,6 +64,14 @@ LLVMErrorRef LLVMRunPasses(LLVMModuleRef M, const char *Passes, FunctionAnalysisManager FAM; CGSCCAnalysisManager CGAM; ModuleAnalysisManager MAM; + if (PassOpts->AAPipeline) { + // If we have a custom AA pipeline, we need to register it _before_ calling + // registerFunctionAnalyses, or the default alias analysis pipeline is used. + AAManager AA; + if (auto Err = PB.parseAAPipeline(AA, PassOpts->AAPipeline)) + return wrap(std::move(Err)); + FAM.registerPass([&] { return std::move(AA); }); + } PB.registerLoopAnalyses(LAM); PB.registerFunctionAnalyses(FAM); PB.registerCGSCCAnalyses(CGAM); @@ -94,6 +106,11 @@ void LLVMPassBuilderOptionsSetDebugLogging(LLVMPassBuilderOptionsRef Options, unwrap(Options)->DebugLogging = DebugLogging; } +void LLVMPassBuilderOptionsSetAAPipeline(LLVMPassBuilderOptionsRef Options, + const char *AAPipeline) { + unwrap(Options)->AAPipeline = AAPipeline; +} + void LLVMPassBuilderOptionsSetLoopInterleaving( LLVMPassBuilderOptionsRef Options, LLVMBool LoopInterleaving) { unwrap(Options)->PTO.LoopInterleaving = LoopInterleaving; diff --git a/llvm/unittests/Passes/PassBuilderBindings/PassBuilderBindingsTest.cpp b/llvm/unittests/Passes/PassBuilderBindings/PassBuilderBindingsTest.cpp index ffa3fdaf6e7e6f..2b06033f0c3fa2 100644 --- a/llvm/unittests/Passes/PassBuilderBindings/PassBuilderBindingsTest.cpp +++ b/llvm/unittests/Passes/PassBuilderBindings/PassBuilderBindingsTest.cpp @@ -60,6 +60,7 @@ TEST_F(PassBuilderCTest, Basic) { LLVMPassBuilderOptionsSetLoopUnrolling(Options, 1); LLVMPassBuilderOptionsSetVerifyEach(Options, 1); LLVMPassBuilderOptionsSetDebugLogging(Options, 0); + LLVMPassBuilderOptionsSetAAPipeline(Options, "basic-aa"); if (LLVMErrorRef E = LLVMRunPasses(Module, "default", TM, Options)) { char *Msg = LLVMGetErrorMessage(E); LLVMConsumeError(E);