Skip to content

Commit

Permalink
[WIP] Delayed privatization.
Browse files Browse the repository at this point in the history
This is a PoC for delayed privatization in OpenMP. Instead of directly
emitting privatization code in the frontend, we add a new op to outline
the privatization logic for a symbol and a call-like mapping that maps
from the host symbol to a block argument in the OpenMP region.

Example:
```
!$omp target private(x)
!$omp end target
```

Would be code-generated by flang as:
```
  func.func @foo() {
    omp.target x.privatizer %x -> %argx: !fir.ref<i32> {
    ^bb0(%argx: !fir.ref<i32>):
      // ... use %argx ....
    }
  }

  "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
  ^bb0(%arg0: !fir.ref<i32>):
    %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
    %1 = fir.load %arg0 : !fir.ref<i32>
    fir.store %1 to %0 : !fir.ref<i32>
    omp.yield(%0 : !fir.ref<i32>)
  }) : () -> ()
```

Later, we would inline the delayed privatizer function-like op into the
OpenMP region to get essentially the same code that is currently generated
directly by the frontend.

So far this PoC implements the following:
- Adds the delayed privatization op: `omp.private`.
- For simple symbols, emits the op.

Still TODO:
- Extend the `omp.target` op to somehow model the outlined privatization
  logic.
- Inline the outlined privatizer before emitting LLVM IR.
- Support more complex symbols like allocatables.
  • Loading branch information
ergawy committed Jan 29, 2024
1 parent ce72f78 commit 9ca8c49
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 46 deletions.
18 changes: 13 additions & 5 deletions flang/include/flang/Lower/AbstractConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "flang/Common/Fortran.h"
#include "flang/Lower/LoweringOptions.h"
#include "flang/Lower/PFTDefs.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Semantics/symbol.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -92,7 +93,8 @@ class AbstractConverter {

/// Binds the symbol to an fir extended value. The symbol binding will be
/// added or replaced at the inner-most level of the local symbol map.
virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval,
Fortran::lower::SymMap *symMap = nullptr) = 0;

/// Override lowering of expression with pre-lowered values.
/// Associate mlir::Value to evaluate::Expr. All subsequent call to
Expand All @@ -111,14 +113,16 @@ class AbstractConverter {
/// For a given symbol which is host-associated, create a clone using
/// parameters from the host-associated symbol.
virtual bool
createHostAssociateVarClone(const Fortran::semantics::Symbol &sym) = 0;
createHostAssociateVarClone(const Fortran::semantics::Symbol &sym,
Fortran::lower::SymMap *symMap = nullptr) = 0;

virtual void
createHostAssociateVarCloneDealloc(const Fortran::semantics::Symbol &sym) = 0;

virtual void copyHostAssociateVar(
const Fortran::semantics::Symbol &sym,
mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) = 0;
virtual void
copyHostAssociateVar(const Fortran::semantics::Symbol &sym,
mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr,
Fortran::lower::SymMap *symMap = nullptr) = 0;

/// For a given symbol, check if it is present in the inner-most
/// level of the symbol map.
Expand Down Expand Up @@ -295,6 +299,10 @@ class AbstractConverter {
return loweringOptions;
}

/// Look up \p sym one level up in the symbol map (e.g. the host symbol for
/// host association in OpenMP code). When \p symMap is null, implementations
/// are expected to search their own local symbol map; otherwise \p symMap is
/// searched instead. NOTE(review): the FirConverter override returns an empty
/// SymbolBox when the symbol is not found -- confirm other implementers match.
virtual Fortran::lower::SymbolBox
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym,
Fortran::lower::SymMap *symMap = nullptr) = 0;

private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
Expand Down
3 changes: 3 additions & 0 deletions flang/include/flang/Lower/SymbolMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ struct SymbolBox : public fir::details::matcher<SymbolBox> {
[](const fir::FortranVariableOpInterface &x) {
return fir::FortranVariableOpInterface(x).getBase();
},
[](const fir::MutableBoxValue &x) {
return x.getAddr();
},
[](const auto &x) { return x.getAddr(); });
}

Expand Down
54 changes: 32 additions & 22 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,16 +498,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Add the symbol binding to the inner-most level of the symbol map and
/// return true if it is not already present. Otherwise, return false.
bool bindIfNewSymbol(Fortran::lower::SymbolRef sym,
const fir::ExtendedValue &exval) {
if (shallowLookupSymbol(sym))
const fir::ExtendedValue &exval,
Fortran::lower::SymMap *symMap = nullptr) {
if (shallowLookupSymbol(sym, symMap))
return false;
bindSymbol(sym, exval);
bindSymbol(sym, exval, symMap);
return true;
}

void bindSymbol(Fortran::lower::SymbolRef sym,
const fir::ExtendedValue &exval) override final {
addSymbol(sym, exval, /*forced=*/true);
const fir::ExtendedValue &exval,
Fortran::lower::SymMap *symMap = nullptr) override final {
addSymbol(sym, exval, /*forced=*/true, symMap);
}

void
Expand Down Expand Up @@ -610,14 +612,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

bool createHostAssociateVarClone(
const Fortran::semantics::Symbol &sym) override final {
const Fortran::semantics::Symbol &sym,
Fortran::lower::SymMap *symMap = nullptr) override final {
mlir::Location loc = genLocation(sym.name());
mlir::Type symType = genType(sym);
const auto *details = sym.detailsIf<Fortran::semantics::HostAssocDetails>();
assert(details && "No host-association found");
const Fortran::semantics::Symbol &hsym = details->symbol();
mlir::Type hSymType = genType(hsym);
Fortran::lower::SymbolBox hsb = lookupSymbol(hsym);
Fortran::lower::SymbolBox hsb = lookupSymbol(hsym, symMap);

auto allocate = [&](llvm::ArrayRef<mlir::Value> shape,
llvm::ArrayRef<mlir::Value> typeParams) -> mlir::Value {
Expand Down Expand Up @@ -720,7 +723,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Do nothing
});

return bindIfNewSymbol(sym, exv);
return bindIfNewSymbol(sym, exv, symMap);
}

void createHostAssociateVarCloneDealloc(
Expand All @@ -745,16 +748,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {

void copyHostAssociateVar(
const Fortran::semantics::Symbol &sym,
mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) override final {
mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr,
Fortran::lower::SymMap *symMap = nullptr) override final {
// 1) Fetch the original copy of the variable.
assert(sym.has<Fortran::semantics::HostAssocDetails>() &&
"No host-association found");
const Fortran::semantics::Symbol &hsym = sym.GetUltimate();
Fortran::lower::SymbolBox hsb = lookupOneLevelUpSymbol(hsym);
Fortran::lower::SymbolBox hsb = lookupOneLevelUpSymbol(hsym, symMap);
assert(hsb && "Host symbol box not found");

// 2) Fetch the copied one that will mask the original.
Fortran::lower::SymbolBox sb = shallowLookupSymbol(sym);
Fortran::lower::SymbolBox sb = shallowLookupSymbol(sym, symMap);
assert(sb && "Host-associated symbol box not found");
assert(hsb.getAddr() != sb.getAddr() &&
"Host and associated symbol boxes are the same");
Expand All @@ -763,8 +767,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
if (copyAssignIP && copyAssignIP->isSet())
builder->restoreInsertionPoint(*copyAssignIP);
else
else {
builder->setInsertionPointAfter(sb.getAddr().getDefiningOp());
}

Fortran::lower::SymbolBox *lhs_sb, *rhs_sb;
if (copyAssignIP && copyAssignIP->isSet() &&
Expand Down Expand Up @@ -1060,17 +1065,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {

/// Find the symbol in the inner-most level of the local map or return null.
Fortran::lower::SymbolBox
shallowLookupSymbol(const Fortran::semantics::Symbol &sym) {
if (Fortran::lower::SymbolBox v = localSymbols.shallowLookupSymbol(sym))
shallowLookupSymbol(const Fortran::semantics::Symbol &sym,
Fortran::lower::SymMap *symMap = nullptr) {
auto &map = (symMap == nullptr ? localSymbols : *symMap);
if (Fortran::lower::SymbolBox v = map.shallowLookupSymbol(sym))
return v;
return {};
}

/// Find the symbol one level up in the symbol map (e.g. for host association
/// in OpenMP code), or return null if it is not found.
Fortran::lower::SymbolBox
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym,
Fortran::lower::SymMap *symMap = nullptr) override {
auto &map = (symMap == nullptr ? localSymbols : *symMap);
if (Fortran::lower::SymbolBox v = map.lookupOneLevelUpSymbol(sym))
return v;
return {};
}
Expand All @@ -1079,15 +1088,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
bool addSymbol(const Fortran::semantics::SymbolRef sym,
fir::ExtendedValue val, bool forced = false) {
if (!forced && lookupSymbol(sym))
fir::ExtendedValue val, bool forced = false,
Fortran::lower::SymMap *symMap = nullptr) {
auto &map = (symMap == nullptr ? localSymbols : *symMap);
if (!forced && lookupSymbol(sym, &map))
return false;
if (lowerToHighLevelFIR()) {
Fortran::lower::genDeclareSymbol(*this, localSymbols, sym, val,
fir::FortranVariableFlagsEnum::None,
forced);
Fortran::lower::genDeclareSymbol(
*this, map, sym, val, fir::FortranVariableFlagsEnum::None, forced);
} else {
localSymbols.addSymbol(sym, val, forced);
map.addSymbol(sym, val, forced);
}
return true;
}
Expand Down
83 changes: 68 additions & 15 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,15 @@ class DataSharingProcessor {
void collectSymbolsForPrivatization();
void insertBarrier();
void collectDefaultSymbols();
void privatize();
void
privatize(llvm::SetVector<mlir::omp::PrivateClauseOp> *privateInitializers);
void defaultPrivatize();
void copyLastPrivatize(mlir::Operation *op);
void insertLastPrivateCompare(mlir::Operation *op);
void cloneSymbol(const Fortran::semantics::Symbol *sym);
void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
void cloneSymbol(const Fortran::semantics::Symbol *sym,
Fortran::lower::SymMap *symMap = nullptr);
void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym,
Fortran::lower::SymMap *symMap);
void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
mlir::OpBuilder::InsertPoint *lastPrivIP);
void insertDeallocs();
Expand All @@ -197,7 +200,8 @@ class DataSharingProcessor {
// Step2 performs the copying for lastprivates and requires knowledge of the
// MLIR operation to insert the last private update. Step2 adds
// deallocation code as well.
void processStep1();
void processStep1(llvm::SetVector<mlir::omp::PrivateClauseOp>
*privateInitializers = nullptr);
void processStep2(mlir::Operation *op, bool isLoop);

void setLoopIV(mlir::Value iv) {
Expand All @@ -206,10 +210,11 @@ class DataSharingProcessor {
}
};

void DataSharingProcessor::processStep1() {
void DataSharingProcessor::processStep1(
llvm::SetVector<mlir::omp::PrivateClauseOp> *privateInitializers) {
collectSymbolsForPrivatization();
collectDefaultSymbols();
privatize();
privatize(privateInitializers);
defaultPrivatize();
insertBarrier();
}
Expand Down Expand Up @@ -239,20 +244,23 @@ void DataSharingProcessor::insertDeallocs() {
}
}

void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym,
Fortran::lower::SymMap *symMap) {
// Privatization for symbols which are pre-determined (like loop index
// variables) happens separately; for everything else, privatize here.
if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
return;
bool success = converter.createHostAssociateVarClone(*sym);
bool success = converter.createHostAssociateVarClone(*sym, symMap);
(void)success;
assert(success && "Privatization failed due to existing binding");
}

void DataSharingProcessor::copyFirstPrivateSymbol(
const Fortran::semantics::Symbol *sym) {
if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
converter.copyHostAssociateVar(*sym);
const Fortran::semantics::Symbol *sym,
Fortran::lower::SymMap *symMap = nullptr) {
if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate)) {
converter.copyHostAssociateVar(*sym, nullptr, symMap);
}
}

void DataSharingProcessor::copyLastPrivateSymbol(
Expand Down Expand Up @@ -487,15 +495,54 @@ void DataSharingProcessor::collectDefaultSymbols() {
}
}

void DataSharingProcessor::privatize() {
void DataSharingProcessor::privatize(
llvm::SetVector<mlir::omp::PrivateClauseOp> *privateInitializers) {

for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {

if (const auto *commonDet =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDet->objects()) {
cloneSymbol(&*mem);
copyFirstPrivateSymbol(&*mem);
}
} else {
if (privateInitializers != nullptr) {
auto ip = firOpBuilder.saveInsertionPoint();

auto moduleOp = firOpBuilder.getInsertionBlock()
->getParentOp()
->getParentOfType<mlir::ModuleOp>();

firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
moduleOp.getBodyRegion().front().end());

Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
assert(hsb && "Host symbol box not found");

auto privatizerOp = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
hsb.getAddr().getLoc(), hsb.getAddr().getType(),
sym->name().ToString());
firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());

Fortran::semantics::Symbol cp = *sym;
Fortran::lower::SymMap privatizerSymbolMap;
privatizerSymbolMap.addSymbol(cp, privatizerOp.getArgument(0));
privatizerSymbolMap.pushScope();

cloneSymbol(&cp, &privatizerSymbolMap);
copyFirstPrivateSymbol(&cp, &privatizerSymbolMap);

firOpBuilder.create<mlir::omp::YieldOp>(
hsb.getAddr().getLoc(),
privatizerSymbolMap.shallowLookupSymbol(cp).getAddr());

firOpBuilder.restoreInsertionPoint(ip);
}

// TODO: This will eventually be an else to the `if` above it. For now, I
// emit both the outlined privatizer AND directly emitted cloning and
// copying ops while I am testing.
cloneSymbol(sym);
copyFirstPrivateSymbol(sym);
}
Expand Down Expand Up @@ -2272,6 +2319,7 @@ static void createBodyOfOp(
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);

// The argument is not currently in memory, so make a temporary for the
// argument, and store it there, then bind that location to the argument.
mlir::Operation *storeOp = nullptr;
Expand All @@ -2291,10 +2339,11 @@ static void createBodyOfOp(

// If it is an unstructured region and is not the outer region of a combined
// construct, create empty blocks for all evaluations.
if (eval.lowerAsUnstructured() && !outerCombined)
if (eval.lowerAsUnstructured() && !outerCombined) {
Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
mlir::omp::YieldOp>(
firOpBuilder, eval.getNestedEvaluations());
}

// Start with privatization, so that the lowering of the nested
// code will use the right symbols.
Expand All @@ -2307,12 +2356,14 @@ static void createBodyOfOp(
if (privatize) {
if (!dsp) {
tempDsp.emplace(converter, *clauses, eval);
tempDsp->processStep1();
llvm::SetVector<mlir::omp::PrivateClauseOp> privateInitializers;
tempDsp->processStep1(&privateInitializers);
}
}

if constexpr (std::is_same_v<Op, mlir::omp::ParallelOp>) {
threadPrivatizeVars(converter, eval);

if (clauses) {
firOpBuilder.setInsertionPoint(marker);
ClauseProcessor(converter, *clauses).processCopyin();
Expand Down Expand Up @@ -2361,6 +2412,7 @@ static void createBodyOfOp(
if (exits.size() == 1)
return exits[0];
mlir::Block *exit = firOpBuilder.createBlock(&region);

for (mlir::Block *b : exits) {
firOpBuilder.setInsertionPointToEnd(b);
firOpBuilder.create<mlir::cf::BranchOp>(loc, exit);
Expand All @@ -2382,8 +2434,9 @@ static void createBodyOfOp(
assert(tempDsp.has_value());
tempDsp->processStep2(op, isLoop);
} else {
if (isLoop && args.size() > 0)
if (isLoop && args.size() > 0) {
dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
}
dsp->processStep2(op, isLoop);
}
}
Expand Down
Loading

0 comments on commit 9ca8c49

Please sign in to comment.