Skip to content

Commit

Permalink
Add private clause to paralle op's definition.
Browse files Browse the repository at this point in the history
  • Loading branch information
ergawy committed Jan 30, 2024
1 parent 8fef2e4 commit 7ff7b4b
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 5 deletions.
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2572,11 +2572,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
/*private_vars=*/mlir::ValueRange(),
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr);
procBindKindAttr, /*private_inits*/ nullptr);
}

static mlir::omp::SectionOp
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
OptionalAttr<SymbolRefArrayAttr>:$private_inits);

let regions = (region AnyRegion:$region);

Expand All @@ -213,6 +215,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
| `private` `(`
custom<PrivateVarList>(
$private_vars, type($private_vars), $private_inits
) `)`
) $region attr-dict
}];
let hasVerifier = 1;
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocators_vars = */ llvm::SmallVector<Value>{},
/* reduction_vars = */ llvm::SmallVector<Value>{},
/*private_vars=*/mlir::ValueRange{},
/* reductions = */ ArrayAttr{},
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
/*private_inits*/ nullptr);
{

OpBuilder::InsertionGuard guard(rewriter);
Expand Down
83 changes: 81 additions & 2 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,8 +989,9 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
/*proc_bind_val=*/nullptr);
/*reduction_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
/*reductions=*/nullptr,
/*proc_bind_val=*/nullptr, /*private_inits*/ nullptr);
state.addAttributes(attributes);
}

Expand Down Expand Up @@ -1607,6 +1608,84 @@ void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
SmallVector<Location>(1, odsState.location));
}

static ParseResult parsePrivateVarList(
OpAsmParser &parser,
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privateInitsAttr) {
SymbolRefAttr privatizerSym;
OpAsmParser::UnresolvedOperand arg;
OpAsmParser::UnresolvedOperand blockArg;
Type argType;

SmallVector<SymbolRefAttr> privateInitsVec;

auto parsePrivatizers = [&]() -> ParseResult {
if (parser.parseAttribute(privatizerSym) || parser.parseOperand(arg) ||
parser.parseArrow() || parser.parseOperand(blockArg)) {
return failure();
}

privateInitsVec.push_back(privatizerSym);
privateVarsOperands.push_back(arg);
return success();
};

auto parseTypes = [&]() -> ParseResult {
if (parser.parseType(argType))
return failure();
privateVarsTypes.push_back(argType);
return success();
};

if (parser.parseCommaSeparatedList(parsePrivatizers))
return failure();

SmallVector<Attribute> privateInits(privateInitsVec.begin(),
privateInitsVec.end());
privateInitsAttr = ArrayAttr::get(parser.getContext(), privateInits);

if (parser.parseColon())
return failure();

if (parser.parseCommaSeparatedList(parseTypes))
return failure();

return success();
}

static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
OperandRange privateVars,
TypeRange privateVarTypes,
std::optional<ArrayAttr> privateInitsAttr) {
auto &region = op->getRegion(0);
unsigned argIndex = 0;
assert(privateVars.size() == privateVarTypes.size() &&
((privateVars.empty()) ||
(*privateInitsAttr &&
(privateInitsAttr->size() == privateVars.size()))));

for (const auto &privateVar : privateVars) {
assert(privateInitsAttr);
const auto &blockArg = region.front().getArgument(argIndex);
const auto &privateInitSym = (*privateInitsAttr)[argIndex];
printer << privateInitSym << " " << privateVar << " -> " << blockArg;

argIndex++;
if (argIndex < privateVars.size())
printer << ", ";
}

printer << " : ";

argIndex = 0;
for (const auto &mapType : privateVarTypes) {
printer << mapType;
argIndex++;
if (argIndex < privateVarTypes.size())
printer << ", ";
}
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/OpenMP/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: fir-opt -verify-diagnostics %s | fir-opt | FileCheck %s

// CHECK-LABEL: _QPprivate_clause
func.func @_QPprivate_clause() {
%0 = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFprivate_clause_allocatableEx"}
%1 = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFprivate_clause_allocatableEy"}

// CHECK: omp.parallel private(@x.privatizer %0 -> %arg0, @y.privatizer %1 -> %arg1 : !fir.ref<i32>, !fir.ref<i32>)
omp.parallel private(@x.privatizer %0 -> %arg0, @y.privatizer %1 -> %arg1 : !fir.ref<i32>, !fir.ref<i32>) {
// CHECK: bb0(%arg0: {{.*}}, %arg1: {{.*}}):
^bb0(%arg0 : !fir.ref<i32>, %arg1 : !fir.ref<i32>):
omp.terminator
}
return
}

// CHECK: "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
// CHECK: ^bb0(%arg0: {{.*}}):
^bb0(%arg0: !fir.ref<i32>):

// CHECK: %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
%0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}

// CHECK: omp.yield(%0 : !fir.ref<i32>)
omp.yield(%0 : !fir.ref<i32>)
}) : () -> ()

// CHECK: "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "y.privatizer"}> ({
"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "y.privatizer"}> ({
^bb0(%arg0: !fir.ref<i32>):

// CHECK: %0 = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFprivate_clause_allocatableEy"}
%0 = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFprivate_clause_allocatableEy"}

// CHECK: omp.yield(%0 : !fir.ref<i32>)
omp.yield(%0 : !fir.ref<i32>)
}) : () -> ()

0 comments on commit 7ff7b4b

Please sign in to comment.