Skip to content

Commit

Permalink
Update bindings to use TargetModel
Browse files Browse the repository at this point in the history
  • Loading branch information
fifield committed Sep 9, 2024
1 parent 85ee662 commit 5522d13
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 22 deletions.
4 changes: 3 additions & 1 deletion include/aie-c/Translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#ifndef AIE_C_TRANSLATION_H
#define AIE_C_TRANSLATION_H

#include "aie-c/TargetModel.h"

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Wrap.h"
Expand Down Expand Up @@ -43,7 +45,7 @@ struct AieRtxControl {
};
using AieRtxControl = struct AieRtxControl;

MLIR_CAPI_EXPORTED AieRtxControl getAieRtxControl(size_t partitionNumCols);
MLIR_CAPI_EXPORTED AieRtxControl getAieRtxControl(AieTargetModel tm);
MLIR_CAPI_EXPORTED void freeAieRtxControl(AieRtxControl aieCtl);
MLIR_CAPI_EXPORTED void aieRtxStartTransaction(AieRtxControl aieCtl);
MLIR_CAPI_EXPORTED void aieRtxDmaUpdateBdAddr(AieRtxControl aieCtl, int col,
Expand Down
9 changes: 3 additions & 6 deletions lib/CAPI/Translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,10 @@ MlirStringRef aieLLVMLink(MlirStringRef *modules, int nModules) {

DEFINE_C_API_PTR_METHODS(AieRtxControl, xilinx::AIE::AIERTXControl)

AieRtxControl getAieRtxControl(size_t partitionNumCols) {
std::vector<AIEDevice> devices{AIEDevice::npu1_1col, AIEDevice::npu1_2col,
AIEDevice::npu1_3col, AIEDevice::npu1_4col,
AIEDevice::npu1};
AieRtxControl getAieRtxControl(AieTargetModel tm) {
// unwrap the target model
const BaseNPUTargetModel &targetModel =
(const BaseNPUTargetModel &)xilinx::AIE::getTargetModel(
devices[partitionNumCols - 1]);
*reinterpret_cast<const BaseNPUTargetModel *>(tm.d);
AIERTXControl *ctl = new AIERTXControl(targetModel);
return wrap(ctl);
}
Expand Down
12 changes: 2 additions & 10 deletions python/AIEMLIRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "aie-c/TargetModel.h"
#include "aie-c/Translation.h"

#include "PyTypes.h"

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
Expand All @@ -28,16 +30,6 @@ using namespace mlir::python::adaptors;
namespace py = pybind11;
using namespace py::literals;

class PyAieTargetModel {
public:
PyAieTargetModel(AieTargetModel model) : model(model) {}
operator AieTargetModel() const { return model; }
AieTargetModel get() const { return model; }

private:
AieTargetModel model;
};

PYBIND11_MODULE(_aie, m) {

aieRegisterAllPasses();
Expand Down
9 changes: 6 additions & 3 deletions python/AIERTXModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
//
//===----------------------------------------------------------------------===//

#include "aie-c/TargetModel.h"
#include "aie-c/Translation.h"

#include "PyTypes.h"

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>

Expand All @@ -20,8 +23,8 @@ using namespace py::literals;

class PyAIERTXControl {
public:
PyAIERTXControl(size_t partitionNumCols)
: ctl(getAieRtxControl(partitionNumCols)) {}
PyAIERTXControl(AieTargetModel targetModel)
: ctl(getAieRtxControl(targetModel)) {}

~PyAIERTXControl() { freeAieRtxControl(ctl); }

Expand All @@ -31,7 +34,7 @@ class PyAIERTXControl {
PYBIND11_MODULE(_aiertx, m) {

py::class_<PyAIERTXControl>(m, "AIERTXControl", py::module_local())
.def(py::init<size_t>(), "partition_num_cols"_a)
.def(py::init<PyAieTargetModel>(), "target_model"_a)
.def("start_transaction",
[](PyAIERTXControl &self) { aieRtxStartTransaction(self.ctl); })
.def("export_serialized_transaction",
Expand Down
25 changes: 25 additions & 0 deletions python/PyTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- PyTypes.h ------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
#ifndef AIE_PYTYPES_H
#define AIE_PYTYPES_H

#include "aie-c/TargetModel.h"

class PyAieTargetModel {
public:
PyAieTargetModel(AieTargetModel model) : model(model) {}
operator AieTargetModel() const { return model; }
AieTargetModel get() const { return model; }

private:
AieTargetModel model;
};

#endif // AIE_PYTYPES_H
5 changes: 3 additions & 2 deletions test/python/aiertx_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from aie.aiertx import AIERTXControl
from util import construct_and_print_module
from aie.dialects.aiex import DDR_AIE_ADDR_OFFSET

from aie.dialects.aie import AIEDevice, get_target_model

# CHECK-LABEL: simple
@construct_and_print_module
def simple(module):
ctl = AIERTXControl(4)
tm = get_target_model(AIEDevice.npu1_4col)
ctl = AIERTXControl(tm)
ctl.start_transaction()
ctl.dma_update_bd_addr(0, 0, DDR_AIE_ADDR_OFFSET, 0)
ctl.export_serialized_transaction()

0 comments on commit 5522d13

Please sign in to comment.