Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jun 16, 2024
1 parent 27ed434 commit 0b2030a
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 496 deletions.
298 changes: 242 additions & 56 deletions compiler/plugins/target/AMD-AIE/aie/AIEDmaToNpu.cpp

Large diffs are not rendered by default.

208 changes: 34 additions & 174 deletions compiler/plugins/target/AMD-AIE/aie/AIEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@
#define DEBUG_TYPE "aie-pass"

using namespace mlir;
using namespace xilinx;
using namespace mlir::vector;
using namespace xilinx::AIE;
using namespace xilinx::AIEX;

const std::map<xilinx::AIE::WireBundle, StrmSwPortType>
_WIRE_BUNDLE_TO_STRM_SW_PORT_TYPE = {
Expand Down Expand Up @@ -88,59 +89,46 @@ xilinx::AIE::WireBundle STRM_SW_PORT_TYPE_TO_WIRE_BUNDLE(StrmSwPortType s) {

template <typename DerivedT>
class AIEAssignBufferAddressesPassBasicBase
: public ::mlir::OperationPass<DeviceOp> {
: public mlir::OperationPass<DeviceOp> {
public:
using Base = AIEAssignBufferAddressesPassBasicBase;

AIEAssignBufferAddressesPassBasicBase()
: ::mlir::OperationPass<DeviceOp>(::mlir::TypeID::get<DerivedT>()) {}
: mlir::OperationPass<DeviceOp>(mlir::TypeID::get<DerivedT>()) {}
AIEAssignBufferAddressesPassBasicBase(
const AIEAssignBufferAddressesPassBasicBase &other)
: ::mlir::OperationPass<DeviceOp>(other) {}
AIEAssignBufferAddressesPassBasicBase &operator=(
const AIEAssignBufferAddressesPassBasicBase &) = delete;
AIEAssignBufferAddressesPassBasicBase(
AIEAssignBufferAddressesPassBasicBase &&) = delete;
AIEAssignBufferAddressesPassBasicBase &operator=(
AIEAssignBufferAddressesPassBasicBase &&) = delete;
~AIEAssignBufferAddressesPassBasicBase() = default;
: mlir::OperationPass<DeviceOp>(other) {}

/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("aie-assign-buffer-addresses-basic");
static constexpr llvm::StringLiteral getArgumentName() {
return llvm::StringLiteral("aie-assign-buffer-addresses-basic");
}
::llvm::StringRef getArgument() const override {

llvm::StringRef getArgument() const override {
return "aie-assign-buffer-addresses-basic";
}

::llvm::StringRef getDescription() const override {
llvm::StringRef getDescription() const override {
return "Assign memory locations for buffers in each tile";
}

/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("AIEAssignBufferAddressesBasic");
static constexpr llvm::StringLiteral getPassName() {
return llvm::StringLiteral("AIEAssignBufferAddressesBasic");
}
::llvm::StringRef getName() const override {

llvm::StringRef getName() const override {
return "AIEAssignBufferAddressesBasic";
}

/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
static bool classof(const mlir::Pass *pass) {
return pass->getTypeID() == mlir::TypeID::get<DerivedT>();
}

/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
std::unique_ptr<mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}

/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {}
void getDependentDialects(mlir::DialectRegistry &registry) const override {}

/// Explicitly declare the TypeID for this class. We declare an explicit
/// private instantiation because Pass classes should only be visible by the
/// current library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
AIEAssignBufferAddressesPassBasicBase<DerivedT>)
};
Expand Down Expand Up @@ -208,32 +196,14 @@ struct AIEAssignBufferAddressesPassBasic
if (address > maxDataMemorySize) {
InFlightDiagnostic error =
tile.emitOpError("allocated buffers exceeded available memory\n");
auto &note = error.attachNote() << "MemoryMap:\n";
auto printbuffer = [&](StringRef name, int address, int size) {
note << "\t" << name << " \t"
<< ": 0x" << llvm::utohexstr(address) << "-0x"
<< llvm::utohexstr(address + size - 1) << " \t(" << size
<< " bytes)\n";
};
if (stacksize > 0)
printbuffer("(stack)", 0, stacksize);
else
error << "(no stack allocated)\n";

for (auto buffer : buffers) {
assert(buffer.getAddress().has_value() &&
"buffer must have address assigned");
printbuffer(buffer.name(), buffer.getAddress().value(),
buffer.getAllocationSize());
}
return signalPassFailure();
}
}
}
};

std::unique_ptr<OperationPass<DeviceOp>>
AIE::createAIEAssignBufferAddressesBasicPass() {
xilinx::AIE::createAIEAssignBufferAddressesBasicPass() {
return std::make_unique<AIEAssignBufferAddressesPassBasic>();
}

Expand Down Expand Up @@ -330,6 +300,7 @@ struct AIEAssignBufferDescriptorIDsPass
bd.setBdId(gen.nextBdId(blockChannelMap[&block]));
}
}

for (TileElement memOp : memOps) {
DenseMap<Block *, int> blockBdIdMap;
for (Block &block : memOp.getOperation()->getRegion(0)) {
Expand Down Expand Up @@ -362,7 +333,7 @@ struct AIEAssignBufferDescriptorIDsPass
};

std::unique_ptr<OperationPass<DeviceOp>>
AIE::createAIEAssignBufferDescriptorIDsPass() {
xilinx::AIE::createAIEAssignBufferDescriptorIDsPass() {
return std::make_unique<AIEAssignBufferDescriptorIDsPass>();
}

Expand Down Expand Up @@ -455,7 +426,8 @@ struct AIEAssignLockIDsPass
}
};

std::unique_ptr<OperationPass<DeviceOp>> AIE::createAIEAssignLockIDsPass() {
std::unique_ptr<OperationPass<DeviceOp>>
xilinx::AIE::createAIEAssignLockIDsPass() {
return std::make_unique<AIEAssignLockIDsPass>();
}

Expand All @@ -470,11 +442,6 @@ std::unique_ptr<OperationPass<DeviceOp>> AIE::createAIEAssignLockIDsPass() {
//
//===----------------------------------------------------------------------===//

using namespace mlir;
using namespace mlir::vector;
using namespace xilinx;
using namespace xilinx::AIE;

static StringRef getArchIntrinsicString(AIEArch arch) { return "aie2"; }

typedef std::tuple<const char *, std::vector<Type>, std::vector<Type>>
Expand Down Expand Up @@ -522,47 +489,6 @@ static void declareAIEIntrinsics(AIEArch arch, OpBuilder &builder) {
registerIntrinsics(getAIE2Intrinsics(builder));
}

template <typename MyAIEOp>
struct AIEOpRemoval : OpConversionPattern<MyAIEOp> {
using OpConversionPattern<MyAIEOp>::OpConversionPattern;
using OpAdaptor = typename MyAIEOp::Adaptor;
ModuleOp &module;

AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1)
: OpConversionPattern<MyAIEOp>(context, benefit), module(m) {}

LogicalResult matchAndRewrite(
MyAIEOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};

struct AIEDebugOpToStdLowering : OpConversionPattern<DebugOp> {
using OpConversionPattern::OpConversionPattern;
ModuleOp &module;

AIEDebugOpToStdLowering(MLIRContext *context, ModuleOp &m,
PatternBenefit benefit = 1)
: OpConversionPattern(context, benefit), module(m) {}

LogicalResult matchAndRewrite(
DebugOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::string funcName = "debug_i32";
auto func = module.lookupSymbol<func::FuncOp>(funcName);
if (!func)
return op.emitOpError("Could not find the intrinsic function ")
<< funcName;
SmallVector<Value, 1> args;
args.push_back(op.getArg());
rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), func, args);
rewriter.eraseOp(op);
return success();
}
};

struct AIEPutStreamToStdLowering : OpConversionPattern<PutStreamOp> {
using OpConversionPattern::OpConversionPattern;
ModuleOp &module;
Expand Down Expand Up @@ -906,9 +832,6 @@ struct AIECoreToStandardPass
return signalPassFailure();
}
DeviceOp device = *m.getOps<DeviceOp>().begin();
AMDAIENPUDeviceModel &targetModel =
mlir::iree_compiler::AMDAIE::getDeviceModel();

// Ensure that we don't have an incorrect target triple. This may override
// some bogus target triple in the original mlir.
m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
Expand Down Expand Up @@ -936,8 +859,8 @@ struct AIECoreToStandardPass
RewritePatternSet patterns(&getContext());
patterns.add<AIEPutStreamToStdLowering, AIEGetStreamToStdLowering,
AIEPutCascadeToStdLowering, AIEGetCascadeToStdLowering,
AIEDebugOpToStdLowering, AIEUseLockToStdLowering,
AIEEventOpToStdLowering>(m.getContext(), m);
AIEUseLockToStdLowering, AIEEventOpToStdLowering>(
m.getContext(), m);

patterns.add<AIEBufferToStandard>(m.getContext(), m, /*benefit*/ 1, tileCol,
tileRow);
Expand All @@ -956,21 +879,14 @@ struct AIECoreToStandardPass
outlineOps<memref::GlobalOp>(device);
outlineOps<func::FuncOp>(device);

RewritePatternSet removepatterns(&getContext());
removepatterns.add<
AIEOpRemoval<DeviceOp>, AIEOpRemoval<TileOp>, AIEOpRemoval<FlowOp>,
AIEOpRemoval<MemOp>, AIEOpRemoval<ShimDMAOp>, AIEOpRemoval<ShimMuxOp>,
AIEOpRemoval<SwitchboxOp>, AIEOpRemoval<LockOp>, AIEOpRemoval<BufferOp>,
AIEOpRemoval<ExternalBufferOp>, AIEOpRemoval<ShimDMAAllocationOp>,
AIEOpRemoval<CascadeFlowOp>, AIEOpRemoval<ConfigureCascadeOp>>(
m.getContext(), m);

if (failed(applyPartialConversion(m, target, std::move(removepatterns))))
return signalPassFailure();
MLIRContext &context = getContext();
IRRewriter rewriter(&context);
rewriter.eraseOp(device);
}
};

std::unique_ptr<OperationPass<ModuleOp>> AIE::createAIECoreToStandardPass() {
std::unique_ptr<OperationPass<ModuleOp>>
xilinx::AIE::createAIECoreToStandardPass() {
return std::make_unique<AIECoreToStandardPass>();
}

Expand Down Expand Up @@ -1476,7 +1392,8 @@ struct AIELocalizeLocksPass
}
};

std::unique_ptr<OperationPass<DeviceOp>> AIE::createAIELocalizeLocksPass() {
std::unique_ptr<OperationPass<DeviceOp>>
xilinx::AIE::createAIELocalizeLocksPass() {
return std::make_unique<AIELocalizeLocksPass>();
}

Expand Down Expand Up @@ -2810,7 +2727,7 @@ struct AIEObjectFifoStatefulTransformPass
};

std::unique_ptr<OperationPass<DeviceOp>>
AIE::createAIEObjectFifoStatefulTransformPass() {
xilinx::AIE::createAIEObjectFifoStatefulTransformPass() {
return std::make_unique<AIEObjectFifoStatefulTransformPass>();
}

Expand Down Expand Up @@ -3265,60 +3182,6 @@ std::optional<std::map<PathEndPoint, SwitchSettings>> Pathfinder::findPaths(
return routingSolution;
}

//===- AIEXToStandard.cpp ---------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (C) 2023, Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//

using namespace xilinx::AIEX;

template <typename MyAIEXOp>
struct AIEXOpRemoval : OpConversionPattern<MyAIEXOp> {
using OpConversionPattern<MyAIEXOp>::OpConversionPattern;
using OpAdaptor = typename MyAIEXOp::Adaptor;
ModuleOp &module;

AIEXOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1)
: OpConversionPattern<MyAIEXOp>(context, benefit), module(m) {}

LogicalResult matchAndRewrite(
MyAIEXOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *Op = op.getOperation();
rewriter.eraseOp(Op);
return success();
}
};

struct AIEXToStandardPass
: xilinx::AIEX::impl::AIEXToStandardBase<AIEXToStandardPass> {
void runOnOperation() override {
ModuleOp m = getOperation();
ConversionTarget target(getContext());
RewritePatternSet removepatterns(&getContext());
removepatterns.add<AIEXOpRemoval<NpuDmaMemcpyNdOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuDmaWaitOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuPushQueueOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuWriteRTPOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuWrite32Op>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuSyncOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuWriteBdOp>>(m.getContext(), m);
removepatterns.add<AIEXOpRemoval<NpuAddressPatchOp>>(m.getContext(), m);

if (failed(applyPartialConversion(m, target, std::move(removepatterns))))
signalPassFailure();
}
};

std::unique_ptr<OperationPass<ModuleOp>> AIEX::createAIEXToStandardPass() {
return std::make_unique<AIEXToStandardPass>();
}

namespace mlir::iree_compiler::AMDAIE {
void registerAIETransformPasses() {
xilinx::AIE::registerAIEAssignLockIDs();
Expand All @@ -3332,8 +3195,5 @@ void registerAIETransformPasses() {
} // namespace mlir::iree_compiler::AMDAIE

namespace mlir::iree_compiler::AMDAIE {
void registerAIEXTransformPasses() {
xilinx::AIEX::registerAIEXToStandard();
xilinx::AIEX::registerAIEDmaToNpu();
}
void registerAIEXTransformPasses() { xilinx::AIEX::registerAIEDmaToNpu(); }
} // namespace mlir::iree_compiler::AMDAIE
Loading

0 comments on commit 0b2030a

Please sign in to comment.