Skip to content

Commit

Permalink
Merge branch 'main' into mem_reduction_stickified
Browse files Browse the repository at this point in the history
  • Loading branch information
imaihal committed Sep 20, 2024
2 parents 53b99c1 + bf905d1 commit 1d4ed1b
Show file tree
Hide file tree
Showing 33 changed files with 2,424 additions and 859 deletions.
46 changes: 45 additions & 1 deletion docs/Dialects/krnl.md
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop.
| :----: | ----------- |
| `none_val` | none type

### `krnl.parallel_clause` (KrnlParallelClauseOp)

_Attach OpenMP clauses to an index variable_


Syntax:

```
operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
attr-dict `:` type($parallel_loop_index)
```

Attach OpenMP clauses to an index variable. That index variable
is used to uniquely associate a parallel loop with its clauses.

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `parallel_loop_index` | index
| `num_threads` | 32-bit signless integer

### `krnl.parallel` (KrnlParallelOp)

_Mark Krnl loops as parallel loops_
Expand All @@ -937,23 +966,38 @@ _Mark Krnl loops as parallel loops_
Syntax:

```
operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops)
operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
```

Parallelize the specified loops. When multiple loop specifiers are passed
as parameters, these loops can be parallelized as a collapsed loop.
krnl.parallel should be placed as the last operator before krnl.iterate,
since we do not want to parallelize the loop until we interpret krnl.block,
krnl.permute and krnl.unroll.

Optionally, a value may specify the number of threads requested for the
parallel loop. A proc_bind string may also be specified; valid values are
"primary", "close", or "spread". Default values are used when not specified.

```
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
```

Traits: `AttrSizedOperandSegments`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `loops` | variadic of any type
| `num_threads` | 32-bit signless integer

### `krnl.permute` (KrnlPermuteOp)

Expand Down
6 changes: 3 additions & 3 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// Store f32 values back to the (normal layout) output.
DimsExpr outputAF = SymListIE(inputAF);
outputAF[E1] = outputAF[E1] + l;
create.vec.storeIE(vecF32H, alloc, outputAF, {});
create.vec.storeIE(vecF32H, alloc, outputAF);
create.vec.storeIE(
vecF32L, alloc, outputAF, {litArchVLHalf.getValue()});
});
Expand All @@ -277,8 +277,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
Value vecF32L = convertOp.getResult(1);
// Save into archVL value buffer.
Value bufferF32 = create.mem.alignedAlloca(bufferType);
create.vec.storeIE(vecF32H, bufferF32, {litZero}, {});
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}, {});
create.vec.storeIE(vecF32H, bufferF32, {litZero});
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf});
// Save the remaining values as scalars.
create.scf.forLoop(litZero.getValue(),
remainingScalarValues.getValue(), 1,
Expand Down
17 changes: 13 additions & 4 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
std::vector<std::string> functionsToDecompose; // common for both
std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool disableQuantZeroPoint; // common for both
bool enableKrnlBufferReuse; // common for both
bool disableMemRefPrefetch; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down Expand Up @@ -195,7 +196,7 @@ static llvm::cl::list<std::string, std::vector<std::string>>
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n"
"Set to 'false' if you want to disable ONNX hybrid pass."),
llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirCommonOptions));
Expand All @@ -208,11 +209,20 @@ static llvm::cl::list<std::string, std::vector<std::string>>

static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
"disable-krnl-op-fusion",
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n"
"Set to 'true' if you want to disable fusion."),
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disable_quantization_zero_point(
"disable-quantization-zero-point",
llvm::cl::desc(
"Disable the use of zero-point in quantization (default=false).\n"
"Set to 'true' if you want to disable the use of zero-point\n"
"in dyn/static quantization/dequantization."),
llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
"enable-krnl-buffer-reuse",
llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass"
Expand All @@ -223,7 +233,7 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
"Set to 'true' if you want to disable prefetch."),
llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));
Expand Down Expand Up @@ -1145,7 +1155,6 @@ std::string getLibraryPath() {
// as lrodataScript.
std::string getToolPath(
const std::string &tool, bool flag /*false by default*/) {

if (!flag) {
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
llvm::SmallString<8> toolPath(execDir);
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
extern std::vector<std::string> functionsToDecompose; // common for both
extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool disableQuantZeroPoint; // common for both
extern bool enableKrnlBufferReuse; // common for both
extern bool disableMemRefPrefetch; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ void addKrnlToLLVMPasses(
// The alloca_scope ops are somewhat fragile; canonicalize remove them when
// redundant, which helps reliability of the compilation of these ops.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass());
}

// The pass below is needed for subview and collapseShape. Unfortunately,
Expand Down
22 changes: 22 additions & 0 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
<< parallelOp << "\n");
// ToFix handle multiple parallel loop
ValueRange loopRefs = parallelOp.getLoops();
Value numThreads = parallelOp.getNumThreads();
StringAttr procBind = parallelOp.getProcBindAttr();
bool needParallelClause =
numThreads || (procBind && procBind.getValue().size() > 0);

// Obtain the the reference the loop that needs to be parallelized
for (Value loopRef : loopRefs) {
Expand Down Expand Up @@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
parallelLoop.getRegion().takeBody(loopToParallel.getRegion());
Operation *yieldOp = &parallelLoop.getBody()->back();
yieldOp->setOperands(reducedValues);
if (needParallelClause) {
// Use clause only for the first one (expected the outermost one).
// Ideally, we would generate here a single, multi-dimensional
// AffineParallelOp, and we would not need to reset the flag.
needParallelClause = false;
// Current approach: insert after yield and then move before it.
PatternRewriter::InsertionGuard insertGuard(builder);
builder.setInsertionPointAfter(yieldOp);
// Get induction variable.
ValueRange optionalLoopIndices = parallelLoop.getIVs();
assert(optionalLoopIndices.size() >= 1 &&
"expected at least one loop index");
Value parallelLoopIndex = optionalLoopIndices[0];
Operation *newOp = opBuilder.create<KrnlParallelClauseOp>(
loc, parallelLoopIndex, numThreads, procBind);
newOp->moveBefore(yieldOp);
}
// Replace the affine.forOp with affine.parallelOp in loopRefToTop
loopRefToOp[loopRef] = parallelLoop;
loopToParallel.erase();
Expand Down Expand Up @@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
target.addIllegalOp<KrnlCopyToBufferOp>();
target.addIllegalOp<KrnlCopyFromBufferOp>();
target.addIllegalOp<KrnlPrefetchOp>();
target.addLegalOp<KrnlParallelClauseOp>();
target.addLegalOp<AffineYieldOp>();
target.addLegalOp<AffineLoadOp>();
target.addLegalOp<AffineStoreOp>();
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
// Nothing to write.
} else {
// Loop to copy the data.
createAffine.forLoopIE(zeroIE, writeUBs[i], 1,
createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
loopIndices.emplace_back(loopInd[0]);
genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref,
Expand Down
19 changes: 12 additions & 7 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1358,9 +1358,15 @@ Value emitScalarOpFor<ONNXDequantizeLinearOp>(
Value scaleFloat = scalarOperands[1];
Value zeroPointInt = scalarOperands[2];

Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
Value xFloat = create.math.cast(elementType, XInt);
Value sub = create.math.sub(xFloat, zeroPointFloat);

Value sub;
if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) {
Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
sub = create.math.sub(xFloat, zeroPointFloat);
} else {
sub = xFloat;
}
Value res = create.math.mul(sub, scaleFloat);
return res;
}
Expand Down Expand Up @@ -1521,8 +1527,7 @@ static LogicalResult getPartiallyFlattenedSimdCode(

create.krnl.simdIterateIE(zero, SymIE(simdUb), VL, simdOnly,
useParallelInSimdLoop, inputs, inputAFs, {output}, {outputAF},
[&](KrnlBuilder &kb, ArrayRef<Value> inputVals,
SmallVectorImpl<Value> &resVals, int64_t VL) {
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Type currElementType = outputElementType;
if (VL > 1)
Expand Down Expand Up @@ -1551,9 +1556,9 @@ static LogicalResult getPartiallyFlattenedSimdCode(
res = emitPostProcessingFor<OP_TYPE>(rewriter, create.getLoc(),
op, currElementType, accumulated);
}
resVals.emplace_back(res);
}); // SIMD kernel.
}); // Outer loops.
return res;
}}); // SIMD kernel.
}); // Outer loops.

rewriter.replaceOp(op, alloc);
return success();
Expand Down
Loading

0 comments on commit 1d4ed1b

Please sign in to comment.