Skip to content

Commit

Permalink
[mlir][bufferization][scf] Implement BufferDeallocationOpInterface fo…
Browse files Browse the repository at this point in the history
…r scf.reduce.return (#66886)

This is necessary to run the new buffer deallocation pipeline as part of
the sparse compiler pipeline.
  • Loading branch information
maerhart authored Sep 20, 2023
1 parent 57a5548 commit ba727ac
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,27 @@ struct InParallelOpInterface
}
};

struct ReduceReturnOpInterface
: public BufferDeallocationOpInterface::ExternalModel<
ReduceReturnOpInterface, scf::ReduceReturnOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
return op->emitError("only supported when operand is not a MemRef");

SmallVector<Value> updatedOperandOwnership;
return deallocation_impl::insertDeallocOpForReturnLike(
state, op, {}, updatedOperandOwnership);
}
};

} // namespace

void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
28 changes: 28 additions & 0 deletions mlir/test/Dialect/SCF/buffer-deallocation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,31 @@ func.func @parallel_insert_slice(%arg0: index) {
// CHECK: }
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
// CHECK-NOT: retain

// -----

func.func @reduce(%buffer: memref<100xf32>) {
%init = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 {
%elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32>
scf.reduce(%elem_to_reduce) : f32 {
^bb0(%lhs : f32, %rhs: f32):
%alloc = memref.alloc() : memref<2xf32>
memref.store %lhs, %alloc [%c0] : memref<2xf32>
memref.store %rhs, %alloc [%c1] : memref<2xf32>
%0 = memref.load %alloc[%c0] : memref<2xf32>
%1 = memref.load %alloc[%c1] : memref<2xf32>
%res = arith.addf %0, %1 : f32
scf.reduce.return %res : f32
}
}
func.return
}

// CHECK-LABEL: func @reduce
// CHECK: scf.reduce
// CHECK: [[ALLOC:%.+]] = memref.alloc(
// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
// CHECK: scf.reduce.return

0 comments on commit ba727ac

Please sign in to comment.