Skip to content

[mlir][python] remove mixins #68853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 58 additions & 69 deletions mlir/docs/Bindings/Python.md
Original file line number Diff line number Diff line change
Expand Up @@ -1017,90 +1017,79 @@ very generic signature.

#### Extending Generated Op Classes

Note that this is a rather complex mechanism and this section errs on the side
of explicitness. Users are encouraged to find an example and duplicate it if
they don't feel the need to understand the subtlety. The `builtin` dialect
provides some relatively simple examples.

As mentioned above, the build system generates Python sources like
`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
often desirable to to use these generated classes as a starting point for
further customization, so an extension mechanism is provided to make this easy
(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
but we prefer a more standard mechanism that is applied uniformly).
often desirable to use these generated classes as a starting point for
further customization, so an extension mechanism is provided to make this easy.
This mechanism uses conventional inheritance combined with `OpView` registration.
For example, the default builder for `arith.constant`

```python
class ConstantOp(_ods_ir.OpView):
OPERATION_NAME = "arith.constant"

_ODS_REGIONS = (0, True)

def __init__(self, value, *, loc=None, ip=None):
...
```

To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
example, the generated code will include an import like this:
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:

```python
try:
from . import _builtin_ops_ext as _ods_ext_module
except ImportError:
_ods_ext_module = None
from typing import Union

from mlir.ir import Type, IntegerAttr, FloatAttr
from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
from mlir.dialects._ods_common import _cext

@_cext.register_operation(_Dialect, replace=True)
class ConstantOpExt(ConstantOp):
def __init__(
self, result: Type, value: Union[int, float], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
```

Then for each generated concrete `OpView` subclass, it will apply a decorator
like:
which enables building an instance of `arith.constant` like so:

```python
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class FuncOp(_ods_ir.OpView):
from mlir.ir import F32Type

a = ConstantOpExt(F32Type.get(), 42.42)
b = ConstantOpExt(IntegerType.get_signless(32), 42)
```

See the `_ods_common.py` `extend_opview_class` function for details of the
mechanism. At a high level:

* If the extension module exists, locate an extension class for the op (in
this example, `FuncOp`):
* First by looking for an attribute with the exact name in the extension
module.
* Falling back to calling a `select_opview_mixin(parent_opview_cls)`
function defined in the extension module.
* If a mixin class is found, a new subclass is dynamically created that
multiply inherits from `({_builtin_ops_ext.FuncOp},
_builtin_ops_gen.FuncOp)`.

The mixin class should not inherit from anything (i.e. directly extends `object`
only). The facility is typically used to define custom `__init__` methods,
properties, instance methods and static methods. Due to the inheritance
ordering, the mixin class can act as though it extends the generated `OpView`
subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
will return `False` but usage generally allows you treat it as duck typed as an
`OpView`).

There are a couple of recommendations, given how the class hierarchy is defined:

* For static methods that need to instantiate the actual "leaf" op (which is
dynamically generated and would result in circular dependencies to try to
reference by name), prefer to use `@classmethod` and the concrete subclass
will be provided as your first `cls` argument. See
`_builtin_ops_ext.FuncOp.from_py_func` as an example.
* If seeking to replace the generated `__init__` method entirely, you may
actually want to invoke the super-super-class `mlir.ir.OpView` constructor
directly, as it takes an `mlir.ir.Operation`, which is likely what you are
constructing (i.e. the generated `__init__` method likely adds more API
constraints than you want to expose in a custom builder).

A pattern that comes up frequently is wanting to provide a sugared `__init__`
method which has optional or type-polymorphism/implicit conversions but to
otherwise want to invoke the default op building logic. For such cases, it is
recommended to use an idiom such as:
Note, three key aspects of the extension mechanism in this example:

1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
2. in this, simplest, case all that's required is a call to the super class' initializer, i.e., `super().__init__(...)`;
3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks))
we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**.

In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders.
I.e., we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**, where the small issue becomes that the latter is already overridden by the generated builder.
Thus, we must call a method of a super class' super class (the "grandparent"); for example:

```python
def __init__(self, sugar, spice, *, loc=None, ip=None):
... massage into result_type, operands, attributes ...
OpView.__init__(self, self.build_generic(
results=[result_type],
operands=operands,
attributes=attributes,
loc=loc,
ip=ip))
from mlir.dialects._scf_ops_gen import _Dialect, ForOp
from mlir.dialects._ods_common import _cext

@_cext.register_operation(_Dialect, replace=True)
class ForOpExt(ForOp):
def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
...
super(ForOp, self).__init__(self.build_generic(...))
```

Refer to the documentation for `build_generic` for more information.
where `OpView.__init__` is called via `super(ForOp, self).__init__`.
Note, there are alternatives ways to implement this (e.g., explicitly writing `OpView.__init__`); see any discussion on Python inheritance.

## Providing Python bindings for a dialect

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ class PyGlobals {
pybind11::object pyClass);

/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass);
pybind11::object pyClass, bool replace = false);

/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}

void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass) {
py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
if (found) {
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
.str());
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a,
"operation_name"_a, "operation_class"_a, "replace"_a = false,
"Testing hook for directly registering an operation");

// Aside from making the globals accessible to python, having python manage
Expand All @@ -63,20 +63,21 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
[](const py::object &dialectClass) -> py::cpp_function {
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
[dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
PyGlobals::get().registerOperationImpl(operationName, opClass);
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);

// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
return opClass;
});
},
"dialect_class"_a,
"dialect_class"_a, "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
Expand Down
19 changes: 0 additions & 19 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
dialects/_affine_ops_ext.py
DIALECT_NAME affine
GEN_ENUM_BINDINGS)

Expand All @@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
dialects/_bufferization_ops_ext.py
DIALECT_NAME bufferization
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
Expand All @@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BuiltinOps.td
SOURCES
dialects/builtin.py
dialects/_builtin_ops_ext.py
DIALECT_NAME builtin)

declare_mlir_dialect_python_bindings(
Expand All @@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/FuncOps.td
SOURCES
dialects/func.py
dialects/_func_ops_ext.py
DIALECT_NAME func)

declare_mlir_dialect_python_bindings(
Expand All @@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgOps.td
SOURCES
dialects/_linalg_ops_ext.py
SOURCES_GLOB
dialects/linalg/*.py
DIALECT_NAME linalg
Expand All @@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
Expand All @@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
Expand All @@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/BufferizationTransformOps.td
SOURCES
dialects/_bufferization_transform_ops_ext.py
dialects/transform/bufferization.py
DIALECT_NAME transform
EXTENSION_NAME bufferization_transform)
Expand All @@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUTransformOps.td
SOURCES
dialects/_gpu_transform_ops_ext.py
dialects/transform/gpu.py
DIALECT_NAME transform
EXTENSION_NAME gpu_transform)
Expand All @@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SCFLoopTransformOps.td
SOURCES
dialects/_loop_transform_ops_ext.py
dialects/transform/loop.py
DIALECT_NAME transform
EXTENSION_NAME loop_transform)
Expand All @@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MemRefTransformOps.td
SOURCES
dialects/_memref_transform_ops_ext.py
dialects/transform/memref.py
DIALECT_NAME transform
EXTENSION_NAME memref_transform)
Expand All @@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform
Expand All @@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TensorTransformOps.td
SOURCES
dialects/_tensor_transform_ops_ext.py
dialects/transform/tensor.py
DIALECT_NAME transform
EXTENSION_NAME tensor_transform)
Expand Down Expand Up @@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/ArithOps.td
SOURCES
dialects/arith.py
dialects/_arith_ops_ext.py
DIALECT_NAME arith
GEN_ENUM_BINDINGS)

Expand All @@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MemRefOps.td
SOURCES
dialects/memref.py
dialects/_memref_ops_ext.py
DIALECT_NAME memref)

declare_mlir_dialect_python_bindings(
Expand All @@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)

declare_mlir_dialect_python_bindings(
Expand Down Expand Up @@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)

Expand All @@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/SCFOps.td
SOURCES
dialects/scf.py
dialects/_scf_ops_ext.py
DIALECT_NAME scf)

declare_mlir_dialect_python_bindings(
Expand All @@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TensorOps.td
SOURCES
dialects/tensor.py
dialects/_tensor_ops_ext.py
DIALECT_NAME tensor)

declare_mlir_dialect_python_bindings(
Expand Down
Loading