Skip to content

Commit

Permalink
[LinalgFunctionOutlining] Outline all ops
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls committed Jan 28, 2025
1 parent 9f5d355 commit 1fe74b7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,
// clang-format off
// https://github.com/llvm/llvm-project/blob/6b0785390d02193d81d8db7fb12279ffa4651afe/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td#L475
// clang-format on
auto type = dyn_cast<MemRefType>(operand.getType());
assert(type && "we've already checked that all operands are memrefs");
MemRefLayoutAttrInterface layout = type.getLayout();
assert(layout &&
if (auto type = dyn_cast<MemRefType>(operand.getType())) {
MemRefLayoutAttrInterface layout = type.getLayout();
assert(layout &&
"MemRefType layout attribute interface should always be present");
if (!layout.isIdentity()) return failure();
if (!layout.isIdentity()) return failure();
}
}
auto funcType = FunctionType::get(
rewriter.getContext(), computeOp->getOperandTypes(), /*outputTypes=*/{});
Expand Down Expand Up @@ -74,13 +74,6 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,

/// Utility to check if the linalg op is one we know should be outlined.
static bool mustOutline(linalg::LinalgOp linalgOp) {
if (isa<linalg::CopyOp, linalg::FillOp>(linalgOp)) return false;
if (isElementwise(linalgOp)) return false;
// TODO(newling) not all remaining ops should be outlined, not even all
// remaining matmuls: below some threshold on size (m*n*k) it's not worth
// outlining (function call overhead). We should extend the set of ops that
// are not outlined here.

return true;
};

Expand Down Expand Up @@ -162,14 +155,6 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() {
moduleOp.walk([&](linalg::LinalgOp computeOp) {
if (!mustOutline(computeOp)) return WalkResult::skip();

// Assert that we're in reference semantics, ie that all operands of
// computeOp have MemRefType:
if (!llvm::all_of(computeOp->getOperandTypes(),
[](Type t) { return isa<MemRefType>(t); })) {
computeOp->emitError("expected all operands to be of MemRefType");
return WalkResult::interrupt();
}

FailureOr<func::FuncOp> maybeFunc =
retrieveOrCreate(rewriter, moduleOp, computeOp);
if (failed(maybeFunc)) return WalkResult::interrupt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,20 @@ func.func @distinct_matmul_shapes(%A0: memref<4x4xbf16>, %B0: memref<4x4xbf16>,

// -----

// CHECK-LABEL: @linalg_fill_copy
// CHECK-DAG: func.func private @[[FILL_FUNC:.*]]({{.*}}: f32, {{.*}}: memref<4xf32>)
// CHECK-DAG: linalg.fill
// CHECK-DAG: func.func private @[[COPY_FUNC:.*]]({{.*}}: memref<4xf32>, {{.*}}: memref<4xf32>)
// CHECK-DAG: linalg.copy
// CHECK: @linalg_fill_copy(%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: func.call @[[FILL_FUNC]](%[[C0]], %[[ARG0]])
// CHECK: func.call @[[COPY_FUNC]](%[[ARG0]], %[[ARG1]])
func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%tile = amdaie.tile(%c1, %c2)
%0 = amdaie.core(%tile, in : [], out : []) {
// CHECK: linalg.fill
// CHECK-NOT: func.call @fill_elementwise_0_outlined
// CHECK: linalg.copy
// CHECK-NOT: func.call @copy_elementwise_1_outlined
linalg.fill ins(%cst : f32) outs(%A : memref<4xf32>)
linalg.copy ins(%A : memref<4xf32>) outs(%B : memref<4xf32>)
amdaie.end
Expand Down

0 comments on commit 1fe74b7

Please sign in to comment.