diff --git a/include/aie-c/Translation.h b/include/aie-c/Translation.h index 4e092ef261..7f829bf2e5 100644 --- a/include/aie-c/Translation.h +++ b/include/aie-c/Translation.h @@ -8,6 +8,8 @@ #ifndef AIE_C_TRANSLATION_H #define AIE_C_TRANSLATION_H +#include "aie-c/TargetModel.h" + #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/CAPI/Wrap.h" @@ -43,7 +45,7 @@ struct AieRtxControl { }; using AieRtxControl = struct AieRtxControl; -MLIR_CAPI_EXPORTED AieRtxControl getAieRtxControl(size_t partitionNumCols); +MLIR_CAPI_EXPORTED AieRtxControl getAieRtxControl(AieTargetModel tm); MLIR_CAPI_EXPORTED void freeAieRtxControl(AieRtxControl aieCtl); MLIR_CAPI_EXPORTED void aieRtxStartTransaction(AieRtxControl aieCtl); MLIR_CAPI_EXPORTED void aieRtxDmaUpdateBdAddr(AieRtxControl aieCtl, int col, diff --git a/lib/CAPI/Translation.cpp b/lib/CAPI/Translation.cpp index 4a5490c812..240d5fb4b4 100644 --- a/lib/CAPI/Translation.cpp +++ b/lib/CAPI/Translation.cpp @@ -195,13 +195,10 @@ MlirStringRef aieLLVMLink(MlirStringRef *modules, int nModules) { DEFINE_C_API_PTR_METHODS(AieRtxControl, xilinx::AIE::AIERTXControl) -AieRtxControl getAieRtxControl(size_t partitionNumCols) { - std::vector devices{AIEDevice::npu1_1col, AIEDevice::npu1_2col, - AIEDevice::npu1_3col, AIEDevice::npu1_4col, - AIEDevice::npu1}; +AieRtxControl getAieRtxControl(AieTargetModel tm) { + // unwrap the target model const BaseNPUTargetModel &targetModel = - (const BaseNPUTargetModel &)xilinx::AIE::getTargetModel( - devices[partitionNumCols - 1]); + *reinterpret_cast(tm.d); AIERTXControl *ctl = new AIERTXControl(targetModel); return wrap(ctl); } diff --git a/python/AIEMLIRModule.cpp b/python/AIEMLIRModule.cpp index ea5d9dd457..e357290f18 100644 --- a/python/AIEMLIRModule.cpp +++ b/python/AIEMLIRModule.cpp @@ -10,6 +10,8 @@ #include "aie-c/TargetModel.h" #include "aie-c/Translation.h" +#include "PyTypes.h" + #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/PybindAdaptors.h" @@ -28,16 +30,6 @@ using namespace mlir::python::adaptors; namespace py = pybind11; using namespace py::literals; -class PyAieTargetModel { -public: - PyAieTargetModel(AieTargetModel model) : model(model) {} - operator AieTargetModel() const { return model; } - AieTargetModel get() const { return model; } - -private: - AieTargetModel model; -}; - PYBIND11_MODULE(_aie, m) { aieRegisterAllPasses(); diff --git a/python/AIERTXModule.cpp b/python/AIERTXModule.cpp index 39f856b49d..e778bcf5b6 100644 --- a/python/AIERTXModule.cpp +++ b/python/AIERTXModule.cpp @@ -8,8 +8,11 @@ // //===----------------------------------------------------------------------===// +#include "aie-c/TargetModel.h" #include "aie-c/Translation.h" +#include "PyTypes.h" + #include #include @@ -20,8 +23,8 @@ using namespace py::literals; class PyAIERTXControl { public: - PyAIERTXControl(size_t partitionNumCols) - : ctl(getAieRtxControl(partitionNumCols)) {} + PyAIERTXControl(AieTargetModel targetModel) + : ctl(getAieRtxControl(targetModel)) {} ~PyAIERTXControl() { freeAieRtxControl(ctl); } @@ -31,7 +34,7 @@ class PyAIERTXControl { PYBIND11_MODULE(_aiertx, m) { py::class_(m, "AIERTXControl", py::module_local()) - .def(py::init(), "partition_num_cols"_a) + .def(py::init(), "target_model"_a) .def("start_transaction", [](PyAIERTXControl &self) { aieRtxStartTransaction(self.ctl); }) .def("export_serialized_transaction", diff --git a/python/PyTypes.h b/python/PyTypes.h new file mode 100644 index 0000000000..589131fb4e --- /dev/null +++ b/python/PyTypes.h @@ -0,0 +1,25 @@ +//===- PyTypes.h ------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// +//===----------------------------------------------------------------------===// +#ifndef AIE_PYTYPES_H +#define AIE_PYTYPES_H + +#include "aie-c/TargetModel.h" + +class PyAieTargetModel { +public: + PyAieTargetModel(AieTargetModel model) : model(model) {} + operator AieTargetModel() const { return model; } + AieTargetModel get() const { return model; } + +private: + AieTargetModel model; +}; + +#endif // AIE_PYTYPES_H \ No newline at end of file diff --git a/test/python/aiertx_bindings.py b/test/python/aiertx_bindings.py index 730ba7444c..670b1317c4 100644 --- a/test/python/aiertx_bindings.py +++ b/test/python/aiertx_bindings.py @@ -9,12 +9,13 @@ from aie.aiertx import AIERTXControl from util import construct_and_print_module from aie.dialects.aiex import DDR_AIE_ADDR_OFFSET - +from aie.dialects.aie import AIEDevice, get_target_model # CHECK-LABEL: simple @construct_and_print_module def simple(module): - ctl = AIERTXControl(4) + tm = get_target_model(AIEDevice.npu1_4col) + ctl = AIERTXControl(tm) ctl.start_transaction() ctl.dma_update_bd_addr(0, 0, DDR_AIE_ADDR_OFFSET, 0) ctl.export_serialized_transaction()