diff --git a/docs/Dialects/krnl.md b/docs/Dialects/krnl.md index c5c76450bd..80ad6224a9 100644 --- a/docs/Dialects/krnl.md +++ b/docs/Dialects/krnl.md @@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop. | :----: | ----------- | | `none_val` | none type +### `krnl.parallel_clause` (KrnlParallelClauseOp) + +_Attach OpenMP clauses to an index variable_ + + +Syntax: + +``` +operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)? + attr-dict `:` type($parallel_loop_index) +``` + +Attach OpenMP clauses to an index variable. That index variable +is used to uniquely associate a parallel loop with its clauses. + +#### Attributes: + + + +
AttributeMLIR TypeDescription
proc_bind::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `parallel_loop_index` | index +| `num_threads` | 32-bit signless integer + ### `krnl.parallel` (KrnlParallelOp) _Mark Krnl loops as parallel loops_ @@ -937,7 +966,7 @@ _Mark Krnl loops as parallel loops_ Syntax: ``` -operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops) +operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops) ``` Parallelize the specified loops. When multiple loop specifiers are passed @@ -945,15 +974,30 @@ as parameters, there loops can be parallelized as a collapsed loop. krnl.parallel should be placed as the last operator before krnl.iterate, Since we do not want to parallelize the loop until we interpret krnl.block, krnl.permute and krnl.unroll. + +Optionally, a value may specify the number of threads requested for the +parallel loop. A proc_bind string may also be specified; valid values are +"primary", "close", or "spread". Default values are used when not specified. + ``` krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop ``` +Traits: `AttrSizedOperandSegments` + +#### Attributes: + + + +
AttributeMLIR TypeDescription
proc_bind::mlir::StringAttrstring attribute
+ #### Operands: | Operand | Description | | :-----: | ----------- | | `loops` | variadic of any type +| `num_threads` | 32-bit signless integer ### `krnl.permute` (KrnlPermuteOp) diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index c8f84e5565..2d796b8f34 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -251,6 +251,7 @@ void addKrnlToLLVMPasses( // The alloca_scope ops are somewhat fragile; canonicalize remove them when // redundant, which helps reliability of the compilation of these ops. pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass()); } // The pass below is needed for subview and collapseShape.. Unfortunately, diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp index cb68b58379..ee31243724 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp @@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, << parallelOp << "\n"); // ToFix handle multiple parallel loop ValueRange loopRefs = parallelOp.getLoops(); + Value numThreads = parallelOp.getNumThreads(); + StringAttr procBind = parallelOp.getProcBindAttr(); + bool needParallelClause = + numThreads || (procBind && procBind.getValue().size() > 0); // Obtain the the reference the loop that needs to be parallelized for (Value loopRef : loopRefs) { @@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, parallelLoop.getRegion().takeBody(loopToParallel.getRegion()); Operation *yieldOp = ¶llelLoop.getBody()->back(); yieldOp->setOperands(reducedValues); + if (needParallelClause) { + // Use clause only for the first one (expected the outermost one). + // Ideally, we would generate here a single, multi-dimensional + // AffineParallelOp, and we would not need to reset the flag. 
+ needParallelClause = false; + // Current approach: insert after yield and then move before it. + PatternRewriter::InsertionGuard insertGuard(builder); + builder.setInsertionPointAfter(yieldOp); + // Get induction variable. + ValueRange optionalLoopIndices = parallelLoop.getIVs(); + assert(optionalLoopIndices.size() >= 1 && + "expected at least one loop index"); + Value parallelLoopIndex = optionalLoopIndices[0]; + Operation *newOp = opBuilder.create( + loc, parallelLoopIndex, numThreads, procBind); + newOp->moveBefore(yieldOp); + } + // Replace the affine.forOp with affine.parallelOp in loopRefToTop loopRefToOp[loopRef] = parallelLoop; loopToParallel.erase(); @@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index 541c3669b0..a645018de9 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -155,7 +155,26 @@ ValueRange KrnlBuilder::getInductionVarValue(ValueRange loops) const { } void KrnlBuilder::parallel(ValueRange loops) const { - b().template create(loc(), loops); + Value noneValue; + StringAttr noneStrAttr; + b().template create(loc(), loops, noneValue, noneStrAttr); +} + +void KrnlBuilder::parallel( + ValueRange loops, Value numThreads, StringAttr procBind) const { + if (procBind.getValue().size() > 0) { + std::string str = procBind.getValue().str(); + assert((str == "primary" || str == "close" || str == "spread") && + "expected primary, close, or spread for proc_bind"); + } + b().template create(loc(), loops, numThreads, procBind); +} + +void KrnlBuilder::parallelClause( + Value parallelLoopIndex, Value numThreads, StringAttr procBind) const { + // No need to check procBind as its values are derived from parallel(...). 
+ b().template create( + loc(), parallelLoopIndex, numThreads, procBind); } void KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops, diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index 6a8d23097c..53f92bdb9a 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -66,6 +66,10 @@ struct KrnlBuilder : public DialectBuilder { void permute(mlir::ValueRange loops, mlir::ArrayRef map) const; mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const; void parallel(mlir::ValueRange loops) const; + void parallel(mlir::ValueRange loops, mlir::Value numThreads, + mlir::StringAttr procBind) const; + void parallelClause(mlir::Value parallelLoopIndex, mlir::Value numThreads, + mlir::StringAttr procBind) const; // Iterate over optimized loops given the original loops, lbs and ubs. Lambda // function implement the body of the loop, and receive a KRNL builder and the diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index 1d89b46f1e..6f8bd48aed 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -514,7 +514,7 @@ def KrnlUnrollOp : Op { }]; } -def KrnlParallelOp : Op { +def KrnlParallelOp : Op { let summary = "Mark Krnl loops as parallel loops"; let description = [{ Parallelize the specified loops. When multiple loop specifiers are passed @@ -522,15 +522,39 @@ def KrnlParallelOp : Op { krnl.parallel should be placed as the last operator before krnl.iterate, Since we do not want to parallelize the loop until we interpret krnl.block, krnl.permute and krnl.unroll. + + Optionally, a value may specify the number of threads requested for the + parallel loop. A proc_bind string may also be specified; valid values are + "primary", "close", or "spread". Default values are used when not specified. 
+ ``` krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop ``` }]; - let arguments = (ins Variadic:$loops); + let arguments = (ins Variadic:$loops, + Optional:$num_threads, + OptionalAttr:$proc_bind); let assemblyFormat = [{ - `(` $loops `)` attr-dict `:` type($loops) + `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops) + }]; +} + +def KrnlParallelClauseOp : Op { + let summary = "Attach OpenMP clauses to an index variable"; + let description = [{ + Attach OpenMP clauses to an index variable. That index variable + is used to uniquely associate a parallel loop with its clauses. + }]; + + let arguments = (ins Index: $parallel_loop_index, + Optional:$num_threads, + OptionalAttr:$proc_bind); + + let assemblyFormat = [{ + `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)? + attr-dict `:` type($parallel_loop_index) + }]; +} diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 166a19217d..4261ee40fb 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -91,6 +91,7 @@ void configureOnnxToKrnlLoweringPass(bool reportOnParallel, bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd, bool simdIsEnabled); std::unique_ptr createProcessScfParallelPrivatePass(); +std::unique_ptr createProcessKrnlParallelClausePass(); #ifdef ONNX_MLIR_ENABLE_STABLEHLO /// Add pass for lowering to Stablehlo IR. 
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index b8285fcb69..720ee19ad9 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -97,6 +97,10 @@ void registerOMPasses(int optLevel) { return createProcessScfParallelPrivatePass(); }); + mlir::registerPass([]() -> std::unique_ptr { + return createProcessKrnlParallelClausePass(); + }); + mlir::registerPass([]() -> std::unique_ptr { return krnl::createConvertSeqToMemrefPass(); }); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index 240f74b4e5..cc51752de0 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -8,12 +8,14 @@ add_onnx_mlir_library(OMLowerKrnlRegion MLIRTransformUtils ) - add_onnx_mlir_library(OMScfParallelPrivateRegion +add_onnx_mlir_library(OMScfParallelPrivateRegion ProcessScfParallelPrivate.cpp + ProcessKrnlParallelClause.cpp LINK_LIBS PUBLIC OMSupport MLIRTransformUtils + MLIROpenMPToLLVM ) add_onnx_mlir_library(OMInstrument diff --git a/src/Transform/ProcessKrnlParallelClause.cpp b/src/Transform/ProcessKrnlParallelClause.cpp new file mode 100644 index 0000000000..2f0d99329a --- /dev/null +++ b/src/Transform/ProcessKrnlParallelClause.cpp @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-- ProcessKrnlParallelClause.cpp - handle Krnl Parallel Clauses ------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// This pass seeks KrnlParallelClauseOp and integrate its parameter in the +// enclosing OpenMP Parallel construct. 
+// +//===----------------------------------------------------------------------===// + +#include "src/Transform/ProcessKrnlParallelClause.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +#include "src/Support/TypeUtilities.hpp" + +#define DEBUG_TYPE "krnl-parallel-clause" + +using namespace mlir; + +namespace { + +struct ProcessKrnlParallelClauseWithoutScopePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + KrnlParallelClauseOp clauseOp, PatternRewriter &rewriter) const final { + // Get Parallel Krnl Clause + Operation *op = clauseOp.getOperation(); + Value numThreads = clauseOp.getNumThreads(); + auto procBind = clauseOp.getProcBind(); + + Operation *parentParallelOp = op->getParentOp(); + while (!llvm::dyn_cast_or_null(parentParallelOp)) + parentParallelOp = parentParallelOp->getParentOp(); + + if (parentParallelOp) { + // Has an enclosing OpenMP parallel construct (expected). + LLVM_DEBUG(llvm::dbgs() + << "Have a KrnlParallelClause with its OMP Parallel op\n"); + omp::ParallelOp parOp = llvm::cast(parentParallelOp); + if (numThreads) { + LLVM_DEBUG(llvm::dbgs() << " with a specific num_threads clause\n"); + // Set the numbers of threads as indicated by clause op. + // WARNING: by moving the use of numThreads from inside the loop to the + // outer OpenMP parallel construct, we may potentially move the use of + // numThreads before its definition. However, because numThreads is by + // definition loop invariant, it is very unlikely that this case occurs. 
+ // Nevertheless, this warning attests that this might be a possibility. + // In such case, we would get a compiler warning/error of use before + // def. + MutableOperandRange mutableNumThreads = parOp.getNumThreadsMutable(); + mutableNumThreads.assign(numThreads); + } + if (procBind.has_value()) { + auto str = procBind.value().str(); + LLVM_DEBUG(llvm::dbgs() + << " with a specific proc_bind clause: " << str << "\n"); + // Set the affinity as indicated by the clause op. + if (str == "primary") + parOp.setProcBindKind(omp::ClauseProcBindKind::Primary); + else if (str == "close") + parOp.setProcBindKind(omp::ClauseProcBindKind::Close); + else if (str == "spread") + parOp.setProcBindKind(omp::ClauseProcBindKind::Spread); + else + llvm_unreachable("unknown proc_bind clause"); + } + } + // Useful info from KrnlParallelClauseOp was extracted, remove now. + rewriter.eraseOp(op); + return success(); + } +}; + +struct ProcessKrnlParallelClausePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ProcessKrnlParallelClausePass) + + ProcessKrnlParallelClausePass() {} + ProcessKrnlParallelClausePass(const ProcessKrnlParallelClausePass &pass) + : mlir::PassWrapper>() {} + + StringRef getArgument() const override { + return "process-krnl-parallel-clause"; + } + + StringRef getDescription() const override { + return "Migrate info from Krnl Parallel Clause into OpenMP Parallel " + "operation."; + } + + void runOnOperation() final; + + typedef PassWrapper> + BaseType; +}; + +void ProcessKrnlParallelClausePass::runOnOperation() { + func::FuncOp function = getOperation(); + MLIRContext *context = &getContext(); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + // Op that is used and removed here. 
+ target.addIllegalOp(); + + RewritePatternSet patterns(context); + onnx_mlir::getKrnlParallelClauseIntoOpenMPPatterns(patterns); + + if (failed(applyPartialConversion(function, target, std::move(patterns)))) + signalPassFailure(); +} + +} // namespace + +void onnx_mlir::getKrnlParallelClauseIntoOpenMPPatterns( + mlir::RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context); +} + +/*! + * Create a Krnl Parallel Clause pass. + */ +std::unique_ptr onnx_mlir::createProcessKrnlParallelClausePass() { + return std::make_unique(); +} diff --git a/src/Transform/ProcessKrnlParallelClause.hpp b/src/Transform/ProcessKrnlParallelClause.hpp new file mode 100644 index 0000000000..7f5a7bc368 --- /dev/null +++ b/src/Transform/ProcessKrnlParallelClause.hpp @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-- ProcessKrnlParallelClause.cpp - handle Krnl Parallel Clauses ------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// This pass seeks KrnlParallelClauseOp and integrate its parameter in the +// enclosing OpenMP Parallel construct. +// +//===----------------------------------------------------------------------===// + +#ifndef ONNX_MLIR_PROCESS_KRNL_PARALLEL_CLAUSE_H +#define ONNX_MLIR_PROCESS_KRNL_PARALLEL_CLAUSE_H + +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +// Exports the patterns. They are all plain rewrite patterns that can be used +// with any PatternRewriter, not conversion patterns. 
+void getKrnlParallelClauseIntoOpenMPPatterns(mlir::RewritePatternSet &patterns); + +} // namespace onnx_mlir +#endif diff --git a/src/Transform/ProcessScfParallelPrivate.cpp b/src/Transform/ProcessScfParallelPrivate.cpp index 998a138878..1996995a37 100644 --- a/src/Transform/ProcessScfParallelPrivate.cpp +++ b/src/Transform/ProcessScfParallelPrivate.cpp @@ -154,7 +154,7 @@ void onnx_mlir::getParallelPrivateScfToScfPatterns( } /*! - * Create a RecomposeONNX pass. + * Create a SCF Parallel Private pass. */ std::unique_ptr onnx_mlir::createProcessScfParallelPrivatePass() { return std::make_unique(); diff --git a/src/Transform/ProcessScfParallelPrivate.hpp b/src/Transform/ProcessScfParallelPrivate.hpp index fe6428c92c..d9450eba47 100644 --- a/src/Transform/ProcessScfParallelPrivate.hpp +++ b/src/Transform/ProcessScfParallelPrivate.hpp @@ -20,8 +20,8 @@ namespace onnx_mlir { -// Exports the RecomposeONNXToONNXPass patterns. They are all plain rewrite -// patterns that can be used with any PatternRewriter, not conversion patterns. +// Exports the patterns. They are all plain rewrite patterns that can be used +// with any PatternRewriter, not conversion patterns. 
void getParallelPrivateScfToScfPatterns(mlir::RewritePatternSet &patterns); } // namespace onnx_mlir diff --git a/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir new file mode 100644 index 0000000000..1124514b79 --- /dev/null +++ b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir @@ -0,0 +1,111 @@ +// RUN: onnx-mlir-opt -O3 --convert-krnl-to-affine --canonicalize %s -split-input-file | FileCheck %s + +// ----- + +func.func @parallel_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_0[0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_1[0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_3[0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %0 = krnl.define_loops 1 + %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.parallel(%loop_block), num_threads(%c8_i32) {proc_bind = "spread"} : !krnl.loop + krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){ + %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index + %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32> + %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32> + %4 = 
arith.addf %2, %3 : vector<32xf32> + vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32> + } + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @parallel_threads_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) { +// CHECK: krnl.parallel_clause([[arg1_]]), num_threads([[CST_8_]]) {proc_bind = "spread"} : index +// CHECK: } +// CHECK: } +} + +// ----- + +func.func @parallel_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_0[0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_1[0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_3[0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %0 = krnl.define_loops 1 + %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.parallel(%loop_block), num_threads(%c8_i32) : !krnl.loop + krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){ + %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index + %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32> + %3 
= vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32> + %4 = arith.addf %2, %3 : vector<32xf32> + vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32> + } + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @parallel_threads +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) { +// CHECK: krnl.parallel_clause([[arg1_]]), num_threads([[CST_8_]]) : index +// CHECK: } +// CHECK: } +} + +// ----- + +func.func @parallel_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_0[0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_1[0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_3[0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %0 = krnl.define_loops 1 + %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.parallel(%loop_block) {proc_bind = "spread"} : !krnl.loop + krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){ + %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index + %2 = vector.load %reshape[%1] 
: memref<16384xf32>, vector<32xf32> + %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32> + %4 = arith.addf %2, %3 : vector<32xf32> + vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32> + } + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @parallel_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} { +// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) { +// CHECK: krnl.parallel_clause([[arg1_]]) {proc_bind = "spread"} : index +// CHECK: } +// CHECK: } +} diff --git a/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir b/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir new file mode 100644 index 0000000000..dd88c9bd38 --- /dev/null +++ b/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir @@ -0,0 +1,175 @@ +// RUN: onnx-mlir-opt -O3 --process-krnl-parallel-clause %s -split-input-file | FileCheck %s + +// ----- + +func.func @omp_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + 
memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + krnl.parallel_clause(%arg1), num_threads(%c8_i32) {proc_bind = "spread"} : index + } + omp.yield + } + omp.terminator + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_threads_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: omp.parallel num_threads([[CST_8_]] : i32) proc_bind(spread) { +} + +// ----- + +func.func @omp_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + 
%alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + krnl.parallel_clause(%arg1), num_threads(%c8_i32) : index + } + omp.yield + } + omp.terminator + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_threads +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: omp.parallel num_threads([[CST_8_]] : i32) { +} + +// ----- + +func.func @omp_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : 
memref<1xindex> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + krnl.parallel_clause(%arg1) {proc_bind = "spread"} : index + } + omp.yield + } + omp.terminator + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: omp.parallel proc_bind(spread) { +} + +// ----- + +func.func @omp_normal(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + %alloc_3 = memref.alloc() 
{alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + } + omp.yield + } + omp.terminator + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_normal +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: omp.parallel { +} +