diff --git a/docs/Dialects/krnl.md b/docs/Dialects/krnl.md
index c5c76450bd..80ad6224a9 100644
--- a/docs/Dialects/krnl.md
+++ b/docs/Dialects/krnl.md
@@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop.
| :----: | ----------- |
| `none_val` | none type
+### `krnl.parallel_clause` (KrnlParallelClauseOp)
+
+_Attach OpenMP clauses to an index variable_
+
+
+Syntax:
+
+```
+operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
+ attr-dict `:` type($parallel_loop_index)
+```
+
+Attach OpenMP clauses to an index variable. That index variable
+is used to uniquely associate a parallel loop with its clauses.
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+proc_bind | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `parallel_loop_index` | index
+| `num_threads` | 32-bit signless integer
+
### `krnl.parallel` (KrnlParallelOp)
_Mark Krnl loops as parallel loops_
@@ -937,7 +966,7 @@ _Mark Krnl loops as parallel loops_
Syntax:
```
-operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops)
+operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
```
Parallelize the specified loops. When multiple loop specifiers are passed
@@ -945,15 +974,30 @@ as parameters, there loops can be parallelized as a collapsed loop.
krnl.parallel should be placed as the last operator before krnl.iterate,
Since we do not want to parallelize the loop until we interpret krnl.block,
krnl.permute and krnl.unroll.
+
+Optionally, a value may specify the number of threads requested for the
+parallel loop. A proc_bind string may also be specified; valid values are
+"primary", "close", or "spread". Default values are used when not specified.
+
```
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
```
+Traits: `AttrSizedOperandSegments`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+proc_bind | ::mlir::StringAttr | string attribute |
+
+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `loops` | variadic of any type
+| `num_threads` | 32-bit signless integer
### `krnl.permute` (KrnlPermuteOp)
diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index c8f84e5565..2d796b8f34 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -251,6 +251,7 @@ void addKrnlToLLVMPasses(
// The alloca_scope ops are somewhat fragile; canonicalize remove them when
// redundant, which helps reliability of the compilation of these ops.
pm.addPass(mlir::createCanonicalizerPass());
+ pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass());
}
// The pass below is needed for subview and collapseShape.. Unfortunately,
diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
index cb68b58379..ee31243724 100644
--- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
+++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
<< parallelOp << "\n");
// ToFix handle multiple parallel loop
ValueRange loopRefs = parallelOp.getLoops();
+ Value numThreads = parallelOp.getNumThreads();
+ StringAttr procBind = parallelOp.getProcBindAttr();
+ bool needParallelClause =
+ numThreads || (procBind && procBind.getValue().size() > 0);
// Obtain the the reference the loop that needs to be parallelized
for (Value loopRef : loopRefs) {
@@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
parallelLoop.getRegion().takeBody(loopToParallel.getRegion());
Operation *yieldOp = ¶llelLoop.getBody()->back();
yieldOp->setOperands(reducedValues);
+ if (needParallelClause) {
+ // Use clause only for the first one (expected to be the outermost one).
+ // Ideally, we would generate here a single, multi-dimensional
+ // AffineParallelOp, and we would not need to reset the flag.
+ needParallelClause = false;
+ // Current approach: insert after yield and then move before it.
+ PatternRewriter::InsertionGuard insertGuard(builder);
+ builder.setInsertionPointAfter(yieldOp);
+ // Get induction variable.
+ ValueRange optionalLoopIndices = parallelLoop.getIVs();
+ assert(optionalLoopIndices.size() >= 1 &&
+ "expected at least one loop index");
+ Value parallelLoopIndex = optionalLoopIndices[0];
+ Operation *newOp = opBuilder.create(
+ loc, parallelLoopIndex, numThreads, procBind);
+ newOp->moveBefore(yieldOp);
+ }
// Replace the affine.forOp with affine.parallelOp in loopRefToTop
loopRefToOp[loopRef] = parallelLoop;
loopToParallel.erase();
@@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
target.addIllegalOp();
target.addIllegalOp();
target.addIllegalOp();
+ target.addLegalOp();
target.addLegalOp();
target.addLegalOp();
target.addLegalOp();
diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp
index 541c3669b0..a645018de9 100644
--- a/src/Dialect/Krnl/DialectBuilder.cpp
+++ b/src/Dialect/Krnl/DialectBuilder.cpp
@@ -155,7 +155,26 @@ ValueRange KrnlBuilder::getInductionVarValue(ValueRange loops) const {
}
void KrnlBuilder::parallel(ValueRange loops) const {
- b().template create(loc(), loops);
+ Value noneValue;
+ StringAttr noneStrAttr;
+ b().template create(loc(), loops, noneValue, noneStrAttr);
+}
+
+void KrnlBuilder::parallel(
+ ValueRange loops, Value numThreads, StringAttr procBind) const {
+ if (procBind.getValue().size() > 0) {
+ std::string str = procBind.getValue().str();
+ assert((str == "primary" || str == "close" || str == "spread") &&
+ "expected primary, close, or spread for proc_bind");
+ }
+ b().template create(loc(), loops, numThreads, procBind);
+}
+
+void KrnlBuilder::parallelClause(
+ Value parallelLoopIndex, Value numThreads, StringAttr procBind) const {
+ // No need to check procBind as its values are derived from parallel(...).
+ b().template create(
+ loc(), parallelLoopIndex, numThreads, procBind);
}
void KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops,
diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp
index 6a8d23097c..53f92bdb9a 100644
--- a/src/Dialect/Krnl/DialectBuilder.hpp
+++ b/src/Dialect/Krnl/DialectBuilder.hpp
@@ -66,6 +66,10 @@ struct KrnlBuilder : public DialectBuilder {
void permute(mlir::ValueRange loops, mlir::ArrayRef map) const;
mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const;
void parallel(mlir::ValueRange loops) const;
+ void parallel(mlir::ValueRange loops, mlir::Value numThreads,
+ mlir::StringAttr procBind) const;
+ void parallelClause(mlir::Value parallelLoopIndex, mlir::Value numThreads,
+ mlir::StringAttr procBind) const;
// Iterate over optimized loops given the original loops, lbs and ubs. Lambda
// function implement the body of the loop, and receive a KRNL builder and the
diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td
index 1d89b46f1e..6f8bd48aed 100644
--- a/src/Dialect/Krnl/Krnl.td
+++ b/src/Dialect/Krnl/Krnl.td
@@ -514,7 +514,7 @@ def KrnlUnrollOp : Op {
}];
}
-def KrnlParallelOp : Op {
+def KrnlParallelOp : Op {
let summary = "Mark Krnl loops as parallel loops";
let description = [{
Parallelize the specified loops. When multiple loop specifiers are passed
@@ -522,15 +522,39 @@ def KrnlParallelOp : Op {
krnl.parallel should be placed as the last operator before krnl.iterate,
Since we do not want to parallelize the loop until we interpret krnl.block,
krnl.permute and krnl.unroll.
+
+ Optionally, a value may specify the number of threads requested for the
+ parallel loop. A proc_bind string may also be specified; valid values are
+ "primary", "close", or "spread". Default values are used when not specified.
+
```
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
```
}];
- let arguments = (ins Variadic:$loops);
+ let arguments = (ins Variadic:$loops,
+ Optional:$num_threads,
+ OptionalAttr:$proc_bind);
let assemblyFormat = [{
- `(` $loops `)` attr-dict `:` type($loops)
+ `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
+ }];
+}
+
+def KrnlParallelClauseOp : Op {
+ let summary = "Attach OpenMP clauses to an index variable";
+ let description = [{
+ Attach OpenMP clauses to an index variable. That index variable
+ is used to uniquely associate a parallel loop with its clauses.
+ }];
+
+ let arguments = (ins Index: $parallel_loop_index,
+ Optional:$num_threads,
+ OptionalAttr:$proc_bind);
+
+ let assemblyFormat = [{
+ `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
+ attr-dict `:` type($parallel_loop_index)
}];
}
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index 166a19217d..4261ee40fb 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -91,6 +91,7 @@ void configureOnnxToKrnlLoweringPass(bool reportOnParallel,
bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd,
bool simdIsEnabled);
std::unique_ptr createProcessScfParallelPrivatePass();
+std::unique_ptr createProcessKrnlParallelClausePass();
#ifdef ONNX_MLIR_ENABLE_STABLEHLO
/// Add pass for lowering to Stablehlo IR.
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index b8285fcb69..720ee19ad9 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -97,6 +97,10 @@ void registerOMPasses(int optLevel) {
return createProcessScfParallelPrivatePass();
});
+ mlir::registerPass([]() -> std::unique_ptr {
+ return createProcessKrnlParallelClausePass();
+ });
+
mlir::registerPass([]() -> std::unique_ptr {
return krnl::createConvertSeqToMemrefPass();
});
diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt
index 240f74b4e5..cc51752de0 100644
--- a/src/Transform/CMakeLists.txt
+++ b/src/Transform/CMakeLists.txt
@@ -8,12 +8,14 @@ add_onnx_mlir_library(OMLowerKrnlRegion
MLIRTransformUtils
)
- add_onnx_mlir_library(OMScfParallelPrivateRegion
+add_onnx_mlir_library(OMScfParallelPrivateRegion
ProcessScfParallelPrivate.cpp
+ ProcessKrnlParallelClause.cpp
LINK_LIBS PUBLIC
OMSupport
MLIRTransformUtils
+ MLIROpenMPToLLVM
)
add_onnx_mlir_library(OMInstrument
diff --git a/src/Transform/ProcessKrnlParallelClause.cpp b/src/Transform/ProcessKrnlParallelClause.cpp
new file mode 100644
index 0000000000..2f0d99329a
--- /dev/null
+++ b/src/Transform/ProcessKrnlParallelClause.cpp
@@ -0,0 +1,149 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===-- ProcessKrnlParallelClause.cpp - handle Krnl Parallel Clauses ------===//
+//
+// Copyright 2024 The IBM Research Authors.
+//
+// =============================================================================
+// This pass seeks KrnlParallelClauseOp and integrates its parameters into the
+// enclosing OpenMP Parallel construct.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Transform/ProcessKrnlParallelClause.hpp"
+#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Pass/Passes.hpp"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+
+#include "src/Support/TypeUtilities.hpp"
+
+#define DEBUG_TYPE "krnl-parallel-clause"
+
+using namespace mlir;
+
+namespace {
+
+struct ProcessKrnlParallelClauseWithoutScopePattern
+ : public OpRewritePattern {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ KrnlParallelClauseOp clauseOp, PatternRewriter &rewriter) const final {
+ // Get Parallel Krnl Clause
+ Operation *op = clauseOp.getOperation();
+ Value numThreads = clauseOp.getNumThreads();
+ auto procBind = clauseOp.getProcBind();
+
+ Operation *parentParallelOp = op->getParentOp();
+ while (!llvm::dyn_cast_or_null(parentParallelOp))
+ parentParallelOp = parentParallelOp->getParentOp();
+
+ if (parentParallelOp) {
+ // Has an enclosing OpenMP parallel construct (expected).
+ LLVM_DEBUG(llvm::dbgs()
+ << "Have a KrnlParallelClause with its OMP Parallel op\n");
+ omp::ParallelOp parOp = llvm::cast(parentParallelOp);
+ if (numThreads) {
+ LLVM_DEBUG(llvm::dbgs() << " with a specific num_threads clause\n");
+ // Set the numbers of threads as indicated by clause op.
+ // WARNING: by moving the use of numThreads from inside the loop to the
+ // outer OpenMP parallel construct, we may potentially move the use of
+ // numThreads before its definition. However, because numThreads is by
+ // definition loop invariant, it is very unlikely that this case occurs.
+ // Nevertheless, this warning attests that this might be a possibility.
+ // In such case, we would get a compiler warning/error of use before
+ // def.
+ MutableOperandRange mutableNumThreads = parOp.getNumThreadsMutable();
+ mutableNumThreads.assign(numThreads);
+ }
+ if (procBind.has_value()) {
+ auto str = procBind.value().str();
+ LLVM_DEBUG(llvm::dbgs()
+ << " with a specific proc_bind clause: " << str << "\n");
+ // Set the affinity as indicated by the clause op.
+ if (str == "primary")
+ parOp.setProcBindKind(omp::ClauseProcBindKind::Primary);
+ else if (str == "close")
+ parOp.setProcBindKind(omp::ClauseProcBindKind::Close);
+ else if (str == "spread")
+ parOp.setProcBindKind(omp::ClauseProcBindKind::Spread);
+ else
+ llvm_unreachable("unknown proc_bind clause");
+ }
+ }
+ // Useful info from KrnlParallelClauseOp was extracted, remove now.
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct ProcessKrnlParallelClausePass
+ : public PassWrapper> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ProcessKrnlParallelClausePass)
+
+ ProcessKrnlParallelClausePass() {}
+ ProcessKrnlParallelClausePass(const ProcessKrnlParallelClausePass &pass)
+ : mlir::PassWrapper>() {}
+
+ StringRef getArgument() const override {
+ return "process-krnl-parallel-clause";
+ }
+
+ StringRef getDescription() const override {
+ return "Migrate info from Krnl Parallel Clause into OpenMP Parallel "
+ "operation.";
+ }
+
+ void runOnOperation() final;
+
+ typedef PassWrapper>
+ BaseType;
+};
+
+void ProcessKrnlParallelClausePass::runOnOperation() {
+ func::FuncOp function = getOperation();
+ MLIRContext *context = &getContext();
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect();
+ // Op that is used and removed here.
+ target.addIllegalOp();
+
+ RewritePatternSet patterns(context);
+ onnx_mlir::getKrnlParallelClauseIntoOpenMPPatterns(patterns);
+
+ if (failed(applyPartialConversion(function, target, std::move(patterns))))
+ signalPassFailure();
+}
+
+} // namespace
+
+void onnx_mlir::getKrnlParallelClauseIntoOpenMPPatterns(
+ mlir::RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.insert(context);
+}
+
+/*!
+ * Create a Krnl Parallel Clause pass.
+ */
+std::unique_ptr onnx_mlir::createProcessKrnlParallelClausePass() {
+ return std::make_unique();
+}
diff --git a/src/Transform/ProcessKrnlParallelClause.hpp b/src/Transform/ProcessKrnlParallelClause.hpp
new file mode 100644
index 0000000000..7f5a7bc368
--- /dev/null
+++ b/src/Transform/ProcessKrnlParallelClause.hpp
@@ -0,0 +1,27 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===-- ProcessKrnlParallelClause.hpp - handle Krnl Parallel Clauses ------===//
+//
+// Copyright 2024 The IBM Research Authors.
+//
+// =============================================================================
+// This pass seeks KrnlParallelClauseOp and integrates its parameters into the
+// enclosing OpenMP Parallel construct.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ONNX_MLIR_PROCESS_KRNL_PARALLEL_CLAUSE_H
+#define ONNX_MLIR_PROCESS_KRNL_PARALLEL_CLAUSE_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace onnx_mlir {
+
+// Exports the patterns. They are all plain rewrite patterns that can be used
+// with any PatternRewriter, not conversion patterns.
+void getKrnlParallelClauseIntoOpenMPPatterns(mlir::RewritePatternSet &patterns);
+
+} // namespace onnx_mlir
+#endif
diff --git a/src/Transform/ProcessScfParallelPrivate.cpp b/src/Transform/ProcessScfParallelPrivate.cpp
index 998a138878..1996995a37 100644
--- a/src/Transform/ProcessScfParallelPrivate.cpp
+++ b/src/Transform/ProcessScfParallelPrivate.cpp
@@ -154,7 +154,7 @@ void onnx_mlir::getParallelPrivateScfToScfPatterns(
}
/*!
- * Create a RecomposeONNX pass.
+ * Create a SCF Parallel Private pass.
*/
std::unique_ptr onnx_mlir::createProcessScfParallelPrivatePass() {
return std::make_unique();
diff --git a/src/Transform/ProcessScfParallelPrivate.hpp b/src/Transform/ProcessScfParallelPrivate.hpp
index fe6428c92c..d9450eba47 100644
--- a/src/Transform/ProcessScfParallelPrivate.hpp
+++ b/src/Transform/ProcessScfParallelPrivate.hpp
@@ -20,8 +20,8 @@
namespace onnx_mlir {
-// Exports the RecomposeONNXToONNXPass patterns. They are all plain rewrite
-// patterns that can be used with any PatternRewriter, not conversion patterns.
+// Exports the patterns. They are all plain rewrite patterns that can be used
+// with any PatternRewriter, not conversion patterns.
void getParallelPrivateScfToScfPatterns(mlir::RewritePatternSet &patterns);
} // namespace onnx_mlir
diff --git a/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir
new file mode 100644
index 0000000000..1124514b79
--- /dev/null
+++ b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir
@@ -0,0 +1,111 @@
+// RUN: onnx-mlir-opt -O3 --convert-krnl-to-affine --canonicalize %s -split-input-file | FileCheck %s
+
+// -----
+
+func.func @parallel_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_0[0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_1[0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_3[0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %0 = krnl.define_loops 1
+ %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+ krnl.parallel(%loop_block), num_threads(%c8_i32) {proc_bind = "spread"} : !krnl.loop
+ krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){
+ %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index
+ %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32>
+ %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32>
+ %4 = arith.addf %2, %3 : vector<32xf32>
+ vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32>
+ }
+ return %alloc : memref<16x8x128xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @parallel_threads_affinity
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} {
+// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32
+// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) {
+// CHECK: krnl.parallel_clause([[arg1_]]), num_threads([[CST_8_]]) {proc_bind = "spread"} : index
+// CHECK: }
+// CHECK: }
+}
+
+// -----
+
+func.func @parallel_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_0[0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_1[0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_3[0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %0 = krnl.define_loops 1
+ %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+ krnl.parallel(%loop_block), num_threads(%c8_i32) : !krnl.loop
+ krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){
+ %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index
+ %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32>
+ %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32>
+ %4 = arith.addf %2, %3 : vector<32xf32>
+ vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32>
+ }
+ return %alloc : memref<16x8x128xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @parallel_threads
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} {
+// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32
+// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) {
+// CHECK: krnl.parallel_clause([[arg1_]]), num_threads([[CST_8_]]) : index
+// CHECK: }
+// CHECK: }
+}
+
+// -----
+
+func.func @parallel_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_0[0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_1[0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ affine.store %c16384, %alloc_3[0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ %0 = krnl.define_loops 1
+ %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+ krnl.parallel(%loop_block) {proc_bind = "spread"} : !krnl.loop
+ krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){
+ %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index
+ %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32>
+ %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32>
+ %4 = arith.addf %2, %3 : vector<32xf32>
+ vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32>
+ }
+ return %alloc : memref<16x8x128xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @parallel_affinity
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} {
+// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) {
+// CHECK: krnl.parallel_clause([[arg1_]]) {proc_bind = "spread"} : index
+// CHECK: }
+// CHECK: }
+}
diff --git a/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir b/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir
new file mode 100644
index 0000000000..dd88c9bd38
--- /dev/null
+++ b/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir
@@ -0,0 +1,175 @@
+// RUN: onnx-mlir-opt -O3 --process-krnl-parallel-clause %s -split-input-file | FileCheck %s
+
+// -----
+
+func.func @omp_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_0[%c0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_0 : memref<1xindex>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_1[%c0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_1 : memref<1xindex>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_3[%c0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_3 : memref<1xindex>
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) {
+ memref.alloca_scope {
+ %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
+ %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
+ %2 = arith.addf %0, %1 : vector<32xf32>
+ vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
+ krnl.parallel_clause(%arg1), num_threads(%c8_i32) {proc_bind = "spread"} : index
+ }
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return %alloc : memref<16x8x128xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @omp_threads_affinity
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32
+// CHECK: omp.parallel num_threads([[CST_8_]] : i32) proc_bind(spread) {
+}
+
+// -----
+
+func.func @omp_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_0[%c0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_0 : memref<1xindex>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_1[%c0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_1 : memref<1xindex>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_3[%c0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_3 : memref<1xindex>
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) {
+ memref.alloca_scope {
+ %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
+ %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
+ %2 = arith.addf %0, %1 : vector<32xf32>
+ vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
+ krnl.parallel_clause(%arg1), num_threads(%c8_i32) : index
+ }
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return %alloc : memref<16x8x128xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @omp_threads
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32
+// CHECK: omp.parallel num_threads([[CST_8_]] : i32) {
+}
+
+// -----
+
+func.func @omp_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_0[%c0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_0 : memref<1xindex>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_1[%c0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_1 : memref<1xindex>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_3[%c0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_3 : memref<1xindex>
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) {
+ memref.alloca_scope {
+ %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
+ %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
+ %2 = arith.addf %0, %1 : vector<32xf32>
+ vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
+ krnl.parallel_clause(%arg1) {proc_bind = "spread"} : index
+ }
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return %alloc : memref<16x8x128xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @omp_affinity
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+// CHECK: omp.parallel proc_bind(spread) {
+}
+
+// -----
+
+func.func @omp_normal(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %c8_i32 = arith.constant 8 : i32
+ %c16384 = arith.constant 16384 : index
+ %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
+ %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_0[%c0] : memref<1xindex>
+ %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_0 : memref<1xindex>
+ %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_1[%c0] : memref<1xindex>
+ %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_1 : memref<1xindex>
+ %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
+ memref.store %c16384, %alloc_3[%c0] : memref<1xindex>
+ %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
+ memref.dealloc %alloc_3 : memref<1xindex>
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) {
+ memref.alloca_scope {
+ %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
+ %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
+ %2 = arith.addf %0, %1 : vector<32xf32>
+ vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
+ }
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return %alloc : memref<16x8x128xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @omp_normal
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
+// CHECK: omp.parallel {
+}
+