Skip to content

Commit

Permalink
Added support to generate OpenMP parallel construct clauses, at this …
Browse files Browse the repository at this point in the history
…time for num_threads and proc_bind (#2944)

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
AlexandreEichenberger committed Sep 19, 2024
1 parent 9dd7c4a commit d03eff2
Show file tree
Hide file tree
Showing 15 changed files with 592 additions and 9 deletions.
46 changes: 45 additions & 1 deletion docs/Dialects/krnl.md
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop.
| :----: | ----------- |
| `none_val` | none type

### `krnl.parallel_clause` (KrnlParallelClauseOp)

_Attach OpenMP clauses to an index variable_


Syntax:

```
operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
attr-dict `:` type($parallel_loop_index)
```

Attach OpenMP clauses to an index variable. That index variable
is used to uniquely associate a parallel loop with its clauses.

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `parallel_loop_index` | index
| `num_threads` | 32-bit signless integer

### `krnl.parallel` (KrnlParallelOp)

_Mark Krnl loops as parallel loops_
Expand All @@ -937,23 +966,38 @@ _Mark Krnl loops as parallel loops_
Syntax:

```
operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops)
operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
```

Parallelize the specified loops. When multiple loop specifiers are passed
as parameters, these loops can be parallelized as a collapsed loop.
krnl.parallel should be placed as the last operator before krnl.iterate,
since we do not want to parallelize the loop until we interpret krnl.block,
krnl.permute and krnl.unroll.

Optionally, a value may specify the number of threads requested for the
parallel loop. A proc_bind string may also be specified; valid values are
"primary", "close", or "spread". Default values are used when not specified.

```
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
```

Traits: `AttrSizedOperandSegments`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `loops` | variadic of any type
| `num_threads` | 32-bit signless integer

### `krnl.permute` (KrnlPermuteOp)

Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ void addKrnlToLLVMPasses(
// The alloca_scope ops are somewhat fragile; canonicalize remove them when
// redundant, which helps reliability of the compilation of these ops.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass());
}

// The pass below is needed for subview and collapseShape.. Unfortunately,
Expand Down
22 changes: 22 additions & 0 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
<< parallelOp << "\n");
// ToFix handle multiple parallel loop
ValueRange loopRefs = parallelOp.getLoops();
Value numThreads = parallelOp.getNumThreads();
StringAttr procBind = parallelOp.getProcBindAttr();
bool needParallelClause =
numThreads || (procBind && procBind.getValue().size() > 0);

// Obtain the reference to the loop that needs to be parallelized
for (Value loopRef : loopRefs) {
Expand Down Expand Up @@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
parallelLoop.getRegion().takeBody(loopToParallel.getRegion());
Operation *yieldOp = &parallelLoop.getBody()->back();
yieldOp->setOperands(reducedValues);
if (needParallelClause) {
// Use clause only for the first one (expected the outermost one).
// Ideally, we would generate here a single, multi-dimensional
// AffineParallelOp, and we would not need to reset the flag.
needParallelClause = false;
// Current approach: insert after yield and then move before it.
PatternRewriter::InsertionGuard insertGuard(builder);
builder.setInsertionPointAfter(yieldOp);
// Get induction variable.
ValueRange optionalLoopIndices = parallelLoop.getIVs();
assert(optionalLoopIndices.size() >= 1 &&
"expected at least one loop index");
Value parallelLoopIndex = optionalLoopIndices[0];
Operation *newOp = opBuilder.create<KrnlParallelClauseOp>(
loc, parallelLoopIndex, numThreads, procBind);
newOp->moveBefore(yieldOp);
}
// Replace the affine.for op with the affine.parallel op in loopRefToOp
loopRefToOp[loopRef] = parallelLoop;
loopToParallel.erase();
Expand Down Expand Up @@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
target.addIllegalOp<KrnlCopyToBufferOp>();
target.addIllegalOp<KrnlCopyFromBufferOp>();
target.addIllegalOp<KrnlPrefetchOp>();
target.addLegalOp<KrnlParallelClauseOp>();
target.addLegalOp<AffineYieldOp>();
target.addLegalOp<AffineLoadOp>();
target.addLegalOp<AffineStoreOp>();
Expand Down
21 changes: 20 additions & 1 deletion src/Dialect/Krnl/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,26 @@ ValueRange KrnlBuilder::getInductionVarValue(ValueRange loops) const {
}

// Mark `loops` as parallel without attaching any clause: the KrnlParallelOp
// is created with a default-constructed (null) num_threads value and a
// default-constructed (null) proc_bind attribute, i.e. both optional
// arguments are left unset.
void KrnlBuilder::parallel(ValueRange loops) const {
  Value noneValue;
  StringAttr noneStrAttr;
  b().template create<KrnlParallelOp>(loc(), loops, noneValue, noneStrAttr);
}

// Mark `loops` as parallel and attach the optional clauses to the created
// KrnlParallelOp.
//   loops:      loop references to parallelize.
//   numThreads: requested number of threads; forwarded as is to the op.
//   procBind:   proc_bind policy; forwarded as is to the op. It is an optional
//               attribute, so a null or empty attribute is accepted and not
//               validated. Any non-empty value must be "primary", "close", or
//               "spread" (checked by assert only).
void KrnlBuilder::parallel(
    ValueRange loops, Value numThreads, StringAttr procBind) const {
  // Guard against a null attribute before reading its value: the clause is
  // optional and callers may legitimately pass no proc_bind at all.
  if (procBind && procBind.getValue().size() > 0) {
    std::string str = procBind.getValue().str();
    assert((str == "primary" || str == "close" || str == "spread") &&
           "expected primary, close, or spread for proc_bind");
  }
  b().template create<KrnlParallelOp>(loc(), loops, numThreads, procBind);
}

// Emit a KrnlParallelClauseOp that carries numThreads and procBind for the
// parallel loop identified by parallelLoopIndex.
void KrnlBuilder::parallelClause(
    Value parallelLoopIndex, Value numThreads, StringAttr procBind) const {
  // procBind is forwarded unchecked: its values are derived from
  // parallel(...), so no validation is repeated here.
  auto &&opBuilder = b();
  auto where = loc();
  opBuilder.template create<KrnlParallelClauseOp>(
      where, parallelLoopIndex, numThreads, procBind);
}

void KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops,
Expand Down
4 changes: 4 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ struct KrnlBuilder : public DialectBuilder {
void permute(mlir::ValueRange loops, mlir::ArrayRef<int64_t> map) const;
mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const;
// Mark the given loops as parallel, with neither a num_threads value nor a
// proc_bind attribute set on the created KrnlParallelOp.
void parallel(mlir::ValueRange loops) const;
// Mark the given loops as parallel and forward numThreads and procBind to the
// created KrnlParallelOp. A non-empty procBind is asserted to be "primary",
// "close", or "spread".
void parallel(mlir::ValueRange loops, mlir::Value numThreads,
mlir::StringAttr procBind) const;
// Create a KrnlParallelClauseOp carrying numThreads and procBind for the loop
// identified by parallelLoopIndex. procBind is not validated here.
void parallelClause(mlir::Value parallelLoopIndex, mlir::Value numThreads,
mlir::StringAttr procBind) const;

// Iterate over optimized loops given the original loops, lbs and ubs. Lambda
// function implement the body of the loop, and receive a KRNL builder and the
Expand Down
30 changes: 27 additions & 3 deletions src/Dialect/Krnl/Krnl.td
Original file line number Diff line number Diff line change
Expand Up @@ -514,23 +514,47 @@ def KrnlUnrollOp : Op<Krnl_Dialect, "unroll"> {
}];
}

def KrnlParallelOp : Op<Krnl_Dialect, "parallel", [AttrSizedOperandSegments]> {
  let summary = "Mark Krnl loops as parallel loops";
  let description = [{
    Parallelize the specified loops. When multiple loop specifiers are passed
    as parameters, these loops can be parallelized as a collapsed loop.
    krnl.parallel should be placed as the last operator before krnl.iterate,
    since we do not want to parallelize the loop until we interpret krnl.block,
    krnl.permute and krnl.unroll.

    Optionally, a value may specify the number of threads requested for the
    parallel loop. A proc_bind string may also be specified; valid values are
    "primary", "close", or "spread". Default values are used when not specified.

    ```
    krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
    krnl.parallel (%i0), num_threads(%n) {proc_bind = "spread"} : !Krnl.loop
    ```
  }];

  // `loops` is variadic and `num_threads` is optional; with two variable-length
  // operand groups the op needs AttrSizedOperandSegments to tell them apart.
  // `proc_bind` is an optional attribute and is printed/parsed via attr-dict.
  let arguments = (ins Variadic<AnyType>:$loops,
    Optional<I32>:$num_threads,
    OptionalAttr<StrAttr>:$proc_bind);

  let assemblyFormat = [{
    `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
  }];
}

def KrnlParallelClauseOp : Op<Krnl_Dialect, "parallel_clause"> {
  let summary = "Attach OpenMP clauses to an index variable";
  let description = [{
    Attach OpenMP clauses to an index variable. That index variable
    is used to uniquely associate a parallel loop with its clauses.

    Both clauses are optional: `num_threads` is a 32-bit integer value and
    `proc_bind` is a string attribute carried in the attribute dictionary.
  }];

  // Only `num_threads` is optional among the operands, so no
  // AttrSizedOperandSegments trait is required here.
  let arguments = (ins Index:$parallel_loop_index,
    Optional<I32>:$num_threads,
    OptionalAttr<StrAttr>:$proc_bind);

  let assemblyFormat = [{
    `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
    attr-dict `:` type($parallel_loop_index)
  }];
}

Expand Down
1 change: 1 addition & 0 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void configureOnnxToKrnlLoweringPass(bool reportOnParallel,
bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd,
bool simdIsEnabled);
std::unique_ptr<mlir::Pass> createProcessScfParallelPrivatePass();
std::unique_ptr<mlir::Pass> createProcessKrnlParallelClausePass();

#ifdef ONNX_MLIR_ENABLE_STABLEHLO
/// Add pass for lowering to Stablehlo IR.
Expand Down
4 changes: 4 additions & 0 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ void registerOMPasses(int optLevel) {
return createProcessScfParallelPrivatePass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createProcessKrnlParallelClausePass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return krnl::createConvertSeqToMemrefPass();
});
Expand Down
4 changes: 3 additions & 1 deletion src/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ add_onnx_mlir_library(OMLowerKrnlRegion
MLIRTransformUtils
)

# Library bundling the ProcessScfParallelPrivate and ProcessKrnlParallelClause
# sources.
add_onnx_mlir_library(OMScfParallelPrivateRegion
  ProcessScfParallelPrivate.cpp
  ProcessKrnlParallelClause.cpp

  LINK_LIBS PUBLIC
  OMSupport
  MLIRTransformUtils
  MLIROpenMPToLLVM
  )

add_onnx_mlir_library(OMInstrument
Expand Down
Loading

0 comments on commit d03eff2

Please sign in to comment.