Skip to content

Commit

Permalink
[mlir][python] expose LLVMStructType API (#81672)
Browse files Browse the repository at this point in the history
Expose the API for constructing and inspecting StructTypes from the LLVM
dialect. Separate constructor methods are used instead of overloads for
better readability, similarly to IntegerType.
  • Loading branch information
ftynse authored Feb 14, 2024
1 parent 6c84709 commit bd8fcf7
Show file tree
Hide file tree
Showing 7 changed files with 525 additions and 3 deletions.
61 changes: 60 additions & 1 deletion mlir/include/mlir-c/Dialect/LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,70 @@ MLIR_CAPI_EXPORTED MlirType
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MlirType const *argumentTypes, bool isVarArg);

/// Creates an LLVM literal (unnamed) struct type.
/// Returns `true` if the type is an LLVM dialect struct type.
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);

/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);

/// Returns the number of fields in the struct. Asserts if the struct is opaque
/// or not yet initialized.
MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type);

/// Returns the `positions`-th field of the struct. Asserts if the struct is
/// opaque, not yet initialized or if the position is out of range.
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type,
intptr_t position);

/// Returns `true` if the struct is packed.
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type);

/// Returns the identifier of the identified struct. Asserts that the struct is
/// identified, i.e., not literal.
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type);

/// Returns `true` is the struct is explicitly opaque (will not have a body) or
/// uninitialized (will eventually have a body).
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type);

/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields
/// have types not compatible with the LLVM dialect. For a graceful failure, use
/// the checked version.
MLIR_CAPI_EXPORTED MlirType
mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
MlirType const *fieldTypes, bool isPacked);

/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
/// diagnostic at the given location and returns null otherwise.
MLIR_CAPI_EXPORTED MlirType
mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
MlirType const *fieldTypes, bool isPacked);

/// Creates an LLVM identified struct type with no body. If a struct type with
/// this name already exists in the context, returns that type. Use
/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
/// potentially renaming it. The body should be set separatelty by calling
/// mlirLLVMStructTypeSetBody, if it isn't set already.
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx,
MlirStringRef name);

/// Creates an LLVM identified struct type with no body and a name starting with
/// the given prefix. If a struct with the exact name as the given prefix
/// already exists, appends an unspecified suffix to the name so that the name
/// is unique in context.
MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
MlirType const *fieldTypes, bool isPacked);

MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
MlirStringRef name);

/// Sets the body of the identified struct if it hasn't been set yet. Returns
/// whether the operation was successful.
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes,
MlirType const *fieldTypes, bool isPacked);

#ifdef __cplusplus
}
#endif
Expand Down
145 changes: 145 additions & 0 deletions mlir/lib/Bindings/Python/DialectLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Diagnostics.h"
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include <string>

namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;

/// RAII scope intercepting all diagnostics into a string. The message must be
/// checked before this goes out of scope.
class CollectDiagnosticsToStringScope {
public:
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
assert(errorMessage.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}

[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }

private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
*static_cast<std::string *>(data) +=
StringRef(message.data, message.length);
};
mlirDiagnosticPrint(diag, printer, data);
return mlirLogicalResultSuccess();
}

MlirContext context;
MlirDiagnosticHandlerID handlerID;
std::string errorMessage = "";
};

void populateDialectLLVMSubmodule(const pybind11::module &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);

llvmStructType.def_classmethod(
"get_literal",
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));

MlirType type = mlirLLVMStructTypeLiteralGetChecked(
loc, elements.size(), elements.data(), packed);
if (mlirTypeIsNull(type)) {
throw py::value_error(scope.takeMessage());
}
return cls(type);
},
py::arg("cls"), py::arg("elements"), py::kw_only(),
py::arg("packed") = false, py::arg("loc") = py::none());

llvmStructType.def_classmethod(
"get_identified",
[](py::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
py::arg("cls"), py::arg("name"), py::kw_only(),
py::arg("context") = py::none());

llvmStructType.def_classmethod(
"get_opaque",
[](py::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
py::arg("cls"), py::arg("name"), py::arg("context") = py::none());

llvmStructType.def(
"set_body",
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
self, elements.size(), elements.data(), packed);
if (!mlirLogicalResultIsSuccess(result)) {
throw py::value_error(
"Struct body already set to different content.");
}
},
py::arg("elements"), py::kw_only(), py::arg("packed") = false);

llvmStructType.def_classmethod(
"new_identified",
[](py::object cls, const std::string &name,
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
return cls(mlirLLVMStructTypeIdentifiedNewGet(
ctx, mlirStringRefCreate(name.data(), name.length()),
elements.size(), elements.data(), packed));
},
py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
py::arg("packed") = false, py::arg("context") = py::none());

llvmStructType.def_property_readonly(
"name", [](MlirType type) -> std::optional<std::string> {
if (mlirLLVMStructTypeIsLiteral(type))
return std::nullopt;

MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
return StringRef(stringRef.data, stringRef.length).str();
});

llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
// Don't crash in absence of a body.
if (mlirLLVMStructTypeIsOpaque(type))
return py::none();

py::list body;
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
++i) {
body.append(mlirLLVMStructTypeGetElementType(type, i));
}
return body;
});

llvmStructType.def_property_readonly(
"packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });

llvmStructType.def_property_readonly(
"opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
}

PYBIND11_MODULE(_mlirDialectsLLVM, m) {
m.doc() = "MLIR LLVM Dialect";

populateDialectLLVMSubmodule(m);
}
68 changes: 67 additions & 1 deletion mlir/lib/CAPI/Dialect/LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,77 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
}

bool mlirTypeIsALLVMStructType(MlirType type) {
return isa<LLVM::LLVMStructType>(unwrap(type));
}

bool mlirLLVMStructTypeIsLiteral(MlirType type) {
return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
}

intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) {
return cast<LLVM::LLVMStructType>(unwrap(type)).getBody().size();
}

MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) {
return wrap(cast<LLVM::LLVMStructType>(unwrap(type)).getBody()[position]);
}

bool mlirLLVMStructTypeIsPacked(MlirType type) {
return cast<LLVM::LLVMStructType>(unwrap(type)).isPacked();
}

MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) {
return wrap(cast<LLVM::LLVMStructType>(unwrap(type)).getName());
}

bool mlirLLVMStructTypeIsOpaque(MlirType type) {
return cast<LLVM::LLVMStructType>(unwrap(type)).isOpaque();
}

MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
MlirType const *fieldTypes,
bool isPacked) {
SmallVector<Type, 2> fieldStorage;
SmallVector<Type> fieldStorage;
return wrap(LLVMStructType::getLiteral(
unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
isPacked));
}

MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc,
intptr_t nFieldTypes,
MlirType const *fieldTypes,
bool isPacked) {
SmallVector<Type> fieldStorage;
return wrap(LLVMStructType::getLiteralChecked(
[loc]() { return emitError(unwrap(loc)); }, unwrap(loc)->getContext(),
unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked));
}

MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) {
return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx)));
}

MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) {
return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name)));
}

MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name,
intptr_t nFieldTypes,
MlirType const *fieldTypes,
bool isPacked) {
SmallVector<Type> fields;
return wrap(LLVMStructType::getNewIdentified(
unwrap(ctx), unwrap(name), unwrapList(nFieldTypes, fieldTypes, fields),
isPacked));
}

MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType,
intptr_t nFieldTypes,
MlirType const *fieldTypes,
bool isPacked) {
SmallVector<Type> fields;
return wrap(
cast<LLVM::LLVMStructType>(unwrap(structType))
.setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked));
}
13 changes: 13 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
MLIRCAPILinalg
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
MODULE_NAME _mlirDialectsLLVM
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectLLVM.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPILLVM
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MODULE_NAME _mlirDialectsQuant
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@

from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
Loading

0 comments on commit bd8fcf7

Please sign in to comment.