Skip to content

[mlir][python] remove mixins #68853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 19, 2023
Merged

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Oct 12, 2023

This PR replaces the mixin OpView extension mechanism with the standard inheritance mechanism.

Why? Firstly, mixins are not very pythonic (inheritance is usually used for this), a little convoluted, and too "tight" (can only be used in the immediately adjacent _ext.py). Secondly, it (mixins) are now blocking are correct implementation of "value builders" (see here) where the problem becomes how to choose the correct base class that the value builder should call.

This PR looks big/complicated but appearances are deceiving; 4 things were needed to make this work:

  1. Drop skipDefaultBuilders in OpPythonBindingGen::emitDefaultOpBuilders
  2. Former mixin extension classes are converted to inherit from the generated OpView instead of being "mixins"
    a. extension classes that simply were calling into an already generated super().__init__ continue to do so
    b. (almost all) extension classes that were calling self.build_generic because of a lack of default builder being generated can now also just call super().__init__
  3. To handle the lone single use-case of select_opview_mixin, namely linalg, only a small change was necessary in opdsl/lang/emitter.py (thanks to the emission/generation of default builders/__init__s)
  4. since the extend_opview_class decorator is removed, we need a way to register extension classes as the desired OpView that op.opview conjures into existence; so we do the standard thing and just enable replacing the existing registered OpView i.e., register_operation(_Dialect, replace=True).

Note, the upgrade path for the common case is to change an extension to inherit from the generated builder and decorate it with register_operation(_Dialect, replace=True). In the slightly more complicated case where super().__init(self.build_generic(...)) is called in the extension's __init__, this needs to be updated to call __init__ in OpView, i.e., the grandparent (see updated docs).

Note, the PR has 3 base commits that look funny but this was done for the purpose of tracking the line history of moving the <DIALECT>_ops_ext.py class into <DIALECT>.py and updating (commit labeled "fix").

@makslevental makslevental changed the title Remove mixins 3 [mlir][python] remove mixins Oct 12, 2023
@github-actions
Copy link

github-actions bot commented Oct 12, 2023

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the remove_mixins_3 branch 3 times, most recently from 7ec1665 to d867537 Compare October 12, 2023 06:59
@makslevental makslevental marked this pull request as ready for review October 12, 2023 07:05
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir:python MLIR Python bindings mlir labels Oct 12, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-core

Author: Maksim Levental (makslevental)

Changes

This PR replaces the mixin OpView extension mechanism with the standard inheritance mechanism.

Why? Firstly, it's not very pythonic (inheritance is usually used for this), a little convoluted, and too "tight" (can only be used in the immediately adjacent _ext.py). Secondly, it (mixins) are now blocking are correct implementation of "value builders" (see here) where the problem becomes how to choose the correct base class that the value builder should call.

This PR looks big/complicated but appearances are deceiving; 4 things were needed to be done to make this work:

  1. Drop skipDefaultBuilders in OpPythonBindingGen::emitDefaultOpBuilders.
  2. Former mixin extension classes are converted to inherit from the generated OpView instead of being "mixins"
    a. extension classes that simply were calling into an already generated super().__init__ continue to do so
    b. (almost all) extension classes that were calling self.build_generic because of a lack of default builder being generated can now also just call super().__init__
  3. To handle the lone single use-case of select_opview_mixin, namely linalg, only a small change was necessary in opdsl/lang/emitter.py (thanks to the emission/generation of default builders/__init__s)
  4. since the extend_opview_class decorator is removed, we need a way to register extension classes as the desired OpView that op.opview conjures into existence; so we do the standard thing and just enable replacing the existing registered OpView i.e., register_operation(_Dialect, replace=True).

Note, the PR has 3 commits that look funny but this was done for the purpose of tracking the line history of moving the &lt;DIALECT&gt;_ops_ext.py class into &lt;DIALECT&gt;.py and updating (commit labeled "fix").


Patch is 196.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68853.diff

45 Files Affected:

  • (modified) mlir/lib/Bindings/Python/Globals.h (+2-2)
  • (modified) mlir/lib/Bindings/Python/IRModule.cpp (+2-2)
  • (modified) mlir/lib/Bindings/Python/MainModule.cpp (+6-5)
  • (modified) mlir/python/CMakeLists.txt (-18)
  • (removed) mlir/python/mlir/dialects/_arith_ops_ext.py (-69)
  • (removed) mlir/python/mlir/dialects/_bufferization_ops_ext.py (-41)
  • (removed) mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py (-128)
  • (removed) mlir/python/mlir/dialects/_builtin_ops_ext.py (-20)
  • (removed) mlir/python/mlir/dialects/_func_ops_ext.py (-319)
  • (removed) mlir/python/mlir/dialects/_gpu_transform_ops_ext.py (-124)
  • (removed) mlir/python/mlir/dialects/_linalg_ops_ext.py (-47)
  • (removed) mlir/python/mlir/dialects/_loop_transform_ops_ext.py (-134)
  • (removed) mlir/python/mlir/dialects/_memref_ops_ext.py (-36)
  • (removed) mlir/python/mlir/dialects/_memref_transform_ops_ext.py (-114)
  • (removed) mlir/python/mlir/dialects/_ml_program_ops_ext.py (-113)
  • (modified) mlir/python/mlir/dialects/_ods_common.py (-59)
  • (removed) mlir/python/mlir/dialects/_pdl_ops_ext.py (-271)
  • (removed) mlir/python/mlir/dialects/_scf_ops_ext.py (-107)
  • (removed) mlir/python/mlir/dialects/_structured_transform_ops_ext.py (-759)
  • (removed) mlir/python/mlir/dialects/_tensor_ops_ext.py (-44)
  • (removed) mlir/python/mlir/dialects/_tensor_transform_ops_ext.py (-64)
  • (removed) mlir/python/mlir/dialects/_transform_ops_ext.py (-176)
  • (removed) mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py (-55)
  • (modified) mlir/python/mlir/dialects/arith.py (+71)
  • (modified) mlir/python/mlir/dialects/bufferization.py (+36)
  • (modified) mlir/python/mlir/dialects/builtin.py (+20)
  • (modified) mlir/python/mlir/dialects/func.py (+323)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+1-1)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+58-49)
  • (modified) mlir/python/mlir/dialects/memref.py (+38)
  • (modified) mlir/python/mlir/dialects/ml_program.py (+114)
  • (modified) mlir/python/mlir/dialects/pdl.py (+285)
  • (modified) mlir/python/mlir/dialects/python_test.py (+6-1)
  • (modified) mlir/python/mlir/dialects/scf.py (+113-2)
  • (modified) mlir/python/mlir/dialects/tensor.py (+37)
  • (modified) mlir/python/mlir/dialects/transform/init.py (+170)
  • (modified) mlir/python/mlir/dialects/transform/bufferization.py (+129)
  • (modified) mlir/python/mlir/dialects/transform/gpu.py (+125)
  • (modified) mlir/python/mlir/dialects/transform/loop.py (+140)
  • (modified) mlir/python/mlir/dialects/transform/memref.py (+115)
  • (modified) mlir/python/mlir/dialects/transform/pdl.py (+50)
  • (modified) mlir/python/mlir/dialects/transform/structured.py (+773)
  • (modified) mlir/python/mlir/dialects/transform/tensor.py (+64)
  • (modified) mlir/python/mlir/runtime/np_to_memref.py (+5-3)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+1-9)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 97cd70089a2e965..21899bdce22e810 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -77,10 +77,10 @@ class PyGlobals {
                            pybind11::object pyClass);
 
   /// Adds a concrete implementation operation class.
-  /// Raises an exception if the mapping already exists.
+  /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerOperationImpl(const std::string &operationName,
-                             pybind11::object pyClass);
+                             pybind11::object pyClass, bool replace = false);
 
   /// Returns the custom Attribute builder for Attribute kind.
   std::optional<pybind11::function>
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 2cc66277abee0f0..a1c8ab7a09ce155 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 }
 
 void PyGlobals::registerOperationImpl(const std::string &operationName,
-                                      py::object pyClass) {
+                                      py::object pyClass, bool replace) {
   py::object &found = operationClassMap[operationName];
-  if (found) {
+  if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Operation '") + operationName +
                               "' is already registered.")
                                  .str());
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index cdddfbe50606d05..a936becf67bea75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
            "dialect_namespace"_a, "dialect_class"_a,
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
-           "operation_name"_a, "operation_class"_a,
+           "operation_name"_a, "operation_class"_a, "replace"_a = false,
            "Testing hook for directly registering an operation");
 
   // Aside from making the globals accessible to python, having python manage
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
       "Class decorator for registering a custom Dialect wrapper");
   m.def(
       "register_operation",
-      [](const py::object &dialectClass) -> py::cpp_function {
+      [](const py::object &dialectClass, bool replace) -> py::cpp_function {
         return py::cpp_function(
-            [dialectClass](py::object opClass) -> py::object {
+            [dialectClass, replace](py::object opClass) -> py::object {
               std::string operationName =
                   opClass.attr("OPERATION_NAME").cast<std::string>();
-              PyGlobals::get().registerOperationImpl(operationName, opClass);
+              PyGlobals::get().registerOperationImpl(operationName, opClass,
+                                                     replace);
 
               // Dict-stuff the new opClass by name onto the dialect class.
               py::object opClassName = opClass.attr("__name__");
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
               return opClass;
             });
       },
-      "dialect_class"_a,
+      "dialect_class"_a, "replace"_a = false,
       "Produce a class decorator for registering an Operation class as part of "
       "a dialect");
   m.def(
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 088d9a765b97730..2eff1cc7c588d8a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -68,7 +68,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/BufferizationOps.td
   SOURCES
     dialects/bufferization.py
-    dialects/_bufferization_ops_ext.py
   DIALECT_NAME bufferization
   GEN_ENUM_BINDINGS_TD_FILE
     "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -80,7 +79,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/BuiltinOps.td
   SOURCES
     dialects/builtin.py
-    dialects/_builtin_ops_ext.py
   DIALECT_NAME builtin)
 
 declare_mlir_dialect_python_bindings(
@@ -105,7 +103,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/FuncOps.td
   SOURCES
     dialects/func.py
-    dialects/_func_ops_ext.py
   DIALECT_NAME func)
 
 declare_mlir_dialect_python_bindings(
@@ -121,7 +118,6 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/LinalgOps.td
   SOURCES
-    dialects/_linalg_ops_ext.py
   SOURCES_GLOB
     dialects/linalg/*.py
   DIALECT_NAME linalg
@@ -142,7 +138,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
 ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TransformPDLExtensionOps.td
   SOURCES
-    dialects/_transform_pdl_extension_ops_ext.py
     dialects/transform/pdl.py
   DIALECT_NAME transform
   EXTENSION_NAME transform_pdl_extension)
@@ -152,7 +147,6 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TransformOps.td
   SOURCES
-    dialects/_transform_ops_ext.py
     dialects/transform/__init__.py
     _mlir_libs/_mlir/dialects/transform/__init__.pyi
   DIALECT_NAME transform
@@ -165,7 +159,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/BufferizationTransformOps.td
   SOURCES
-    dialects/_bufferization_transform_ops_ext.py
     dialects/transform/bufferization.py
   DIALECT_NAME transform
   EXTENSION_NAME bufferization_transform)
@@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/GPUTransformOps.td
   SOURCES
-    dialects/_gpu_transform_ops_ext.py
     dialects/transform/gpu.py
   DIALECT_NAME transform
   EXTENSION_NAME gpu_transform)
@@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/SCFLoopTransformOps.td
   SOURCES
-    dialects/_loop_transform_ops_ext.py
     dialects/transform/loop.py
   DIALECT_NAME transform
   EXTENSION_NAME loop_transform)
@@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/MemRefTransformOps.td
   SOURCES
-    dialects/_memref_transform_ops_ext.py
     dialects/transform/memref.py
   DIALECT_NAME transform
   EXTENSION_NAME memref_transform)
@@ -214,7 +204,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/LinalgStructuredTransformOps.td
   SOURCES
-    dialects/_structured_transform_ops_ext.py
     dialects/transform/structured.py
   DIALECT_NAME transform
   EXTENSION_NAME structured_transform
@@ -236,7 +225,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TensorTransformOps.td
   SOURCES
-    dialects/_tensor_transform_ops_ext.py
     dialects/transform/tensor.py
   DIALECT_NAME transform
   EXTENSION_NAME tensor_transform)
@@ -266,7 +254,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/ArithOps.td
   SOURCES
     dialects/arith.py
-    dialects/_arith_ops_ext.py
   DIALECT_NAME arith
   GEN_ENUM_BINDINGS)
 
@@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/MemRefOps.td
   SOURCES
     dialects/memref.py
-    dialects/_memref_ops_ext.py
   DIALECT_NAME memref)
 
 declare_mlir_dialect_python_bindings(
@@ -285,7 +271,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/MLProgramOps.td
   SOURCES
     dialects/ml_program.py
-    dialects/_ml_program_ops_ext.py
   DIALECT_NAME ml_program)
 
 declare_mlir_dialect_python_bindings(
@@ -329,7 +314,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/PDLOps.td
   SOURCES
     dialects/pdl.py
-    dialects/_pdl_ops_ext.py
     _mlir_libs/_mlir/dialects/pdl.pyi
   DIALECT_NAME pdl)
 
@@ -347,7 +331,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/SCFOps.td
   SOURCES
     dialects/scf.py
-    dialects/_scf_ops_ext.py
   DIALECT_NAME scf)
 
 declare_mlir_dialect_python_bindings(
@@ -373,7 +356,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/TensorOps.td
   SOURCES
     dialects/tensor.py
-    dialects/_tensor_ops_ext.py
   DIALECT_NAME tensor)
 
 declare_mlir_dialect_python_bindings(
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
deleted file mode 100644
index df38f871710fe8f..000000000000000
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ /dev/null
@@ -1,69 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from ..ir import *
-    from ._ods_common import get_default_loc_context as _get_default_loc_context
-
-    from typing import Any, List, Union
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-
-def _isa(obj: Any, cls: type):
-    try:
-        cls(obj)
-    except ValueError:
-        return False
-    return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
-    return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
-    return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
-    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
-class ConstantOp:
-    """Specialization for the constant op class."""
-
-    def __init__(
-        self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
-    ):
-        if isinstance(value, int):
-            super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
-        elif isinstance(value, float):
-            super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
-        else:
-            super().__init__(value, loc=loc, ip=ip)
-
-    @classmethod
-    def create_index(cls, value: int, *, loc=None, ip=None):
-        """Create an index-typed constant."""
-        return cls(
-            IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
-        )
-
-    @property
-    def type(self):
-        return self.results[0].type
-
-    @property
-    def value(self):
-        return Attribute(self.operation.attributes["value"])
-
-    @property
-    def literal_value(self) -> Union[int, float]:
-        if _is_integer_like_type(self.type):
-            return IntegerAttr(self.value).value
-        elif _is_float_type(self.type):
-            return FloatAttr(self.value).value
-        else:
-            raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
deleted file mode 100644
index 1066cb4c775cab9..000000000000000
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ /dev/null
@@ -1,41 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from typing import Sequence, Union
-    from ..ir import *
-    from ._ods_common import get_default_loc_context
-
-    from typing import Any, List, Union
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-
-class AllocTensorOp:
-    """Extends the bufferization.alloc_tensor op."""
-
-    def __init__(
-        self,
-        tensor_type: Type,
-        dynamic_sizes: Sequence[Value],
-        copy: Value,
-        size_hint: Value,
-        escape: BoolAttr,
-        *,
-        loc=None,
-        ip=None
-    ):
-        """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
-        context = get_default_loc_context(loc)
-        attributes = {}
-        if escape:
-            attributes["escape"] = escape
-        op = self.build_generic(
-            results=[tensor_type],
-            operands=[dynamic_sizes, copy, size_hint],
-            attributes=attributes,
-            loc=loc,
-            ip=ip,
-        )
-        OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
deleted file mode 100644
index 7e6c1b81cb350b7..000000000000000
--- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from ..ir import *
-    from ..dialects import transform
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from enum import Enum
-from typing import Optional, overload, Union
-
-
-class EmptyTensorToAllocTensorOp:
-    """Specialization for EmptyTensorToAllocTensorOp class."""
-
-    @overload
-    def __init__(
-        self,
-        transformed_type: Type,
-        target: Union[Operation, OpView, Value],
-        *,
-        loc=None,
-        ip=None
-    ):
-        ...
-
-    @overload
-    def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
-        ...
-
-    def __init__(
-        self,
-        transformed_type_or_target: Type,
-        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
-        *,
-        loc=None,
-        ip=None
-    ):
-        if isinstance(transformed_type_or_target, Type):
-            transformed_type = transformed_type_or_target
-            target = target_or_none
-        else:
-            transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
-            target = transformed_type_or_target
-
-        super().__init__(
-            transformed_type,
-            target,
-            loc=loc,
-            ip=ip,
-        )
-
-
-class OneShotBufferizeOp:
-    """Specialization for OneShotBufferizeOp class."""
-
-    @overload
-    def __init__(
-        self,
-        transformed_type: Type,
-        target: Union[Operation, OpView, Value],
-        *,
-        allow_return_allocs_from_loops: Optional[bool] = None,
-        allow_unknown_ops: Optional[bool] = None,
-        bufferize_function_boundaries: Optional[bool] = None,
-        function_boundary_type_conversion: Optional[Enum] = None,
-        memcpy_op: Optional[str] = None,
-        print_conflicts: Optional[bool] = None,
-        test_analysis_only: Optional[bool] = None,
-        loc=None,
-        ip=None
-    ):
-        ...
-
-    @overload
-    def __init__(
-        self,
-        target: Union[Operation, OpView, Value],
-        *,
-        allow_return_allocs_from_loops: Optional[bool] = None,
-        allow_unknown_ops: Optional[bool] = None,
-        bufferize_function_boundaries: Optional[bool] = None,
-        function_boundary_type_conversion: Optional[Enum] = None,
-        memcpy_op: Optional[str] = None,
-        print_conflicts: Optional[bool] = None,
-        test_analysis_only: Optional[bool] = None,
-        loc=None,
-        ip=None
-    ):
-        ...
-
-    def __init__(
-        self,
-        transformed_type_or_target: Type,
-        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
-        *,
-        allow_return_allocs_from_loops: Optional[bool] = None,
-        allow_unknown_ops: Optional[bool] = None,
-        bufferize_function_boundaries: Optional[bool] = None,
-        function_boundary_type_conversion: Optional[Enum] = None,
-        memcpy_op: Optional[str] = None,
-        print_conflicts: Optional[bool] = None,
-        test_analysis_only: Optional[bool] = None,
-        loc=None,
-        ip=None
-    ):
-        if isinstance(transformed_type_or_target, Type):
-            transformed_type = transformed_type_or_target
-            target = target_or_none
-        else:
-            transformed_type = transform.AnyOpType.get()
-            target = transformed_type_or_target
-
-        super().__init__(
-            transformed_type,
-            target,
-            allow_return_allocs_from_loops=allow_return_allocs_from_loops,
-            allow_unknown_ops=allow_unknown_ops,
-            bufferize_function_boundaries=bufferize_function_boundaries,
-            function_boundary_type_conversion=function_boundary_type_conversion,
-            memcpy_op=memcpy_op,
-            print_conflicts=print_conflicts,
-            test_analysis_only=test_analysis_only,
-            loc=loc,
-            ip=ip,
-        )
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
deleted file mode 100644
index 27a60123050acb4..000000000000000
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ /dev/null
@@ -1,20 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from ..ir import *
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-
-class ModuleOp:
-    """Specialization for the module op class."""
-
-    def __init__(self, *, loc=None, ip=None):
-        super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
-        body = self.regions[0].blocks.append()
-
-    @property
-    def body(self):
-        return self.regions[0].blocks[0]
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
deleted file mode 100644
index 6d264c33f1f9dae..000000000000000
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ /dev/null
@@ -1,319 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from ..ir import *
-    from ._ods_common import get_default_loc_context as _get_default_loc_context
-
-    import inspect
-
-    from typing import Any, List, Optional, Sequence, Union
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
-RESULT_ATTRIBUTE_NAME = "res_attrs"
-
-
-class ConstantOp:
-    """Specialization for the constant op class."""
-
-    def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-        super().__init__(result, value, loc=loc, ip=ip)
-
-    @property
-    def type(self):
-        return self.results[0].type
-
-
-class FuncOp:
-    """Specialization for the func op class."""
-
-    def __init__(
-        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
-    ):
-        """
-        Create a FuncOp with the provided `name`, `type`, and `visibility`.
-        - `name` is a string representing the function name.
-        - `type` is either a FunctionType or a pair of list describing inputs and
-          results.
-        - `visibility` is a string matching `public`, `private`, or `nested`. None
-          implies private visibility.
-        - `body_builder` is an optional callback, when provided a new entry block
-          is created and the callback is invoked with the new op as argument within
-          an InsertionPoint context already set for the block. The callback is
-          expected to insert a terminator in the block.
-        """
-        sym_name = StringAttr.get(str(name))
-
-        # If the type is passed as a tuple, build a FunctionType on the fly.
-        if isinstance(type, tuple):
-            type = FunctionType.get(inputs=type[0], results=type[1])
-
-        type = TypeAttr.get(type)
-        sym_visibility = (
-            StringAttr.get(str(visibility)) if visibility is not None else None
-        )
-        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip...
[truncated]

@stellaraccident
Copy link
Contributor

Thank you. I need to review this carefully and can't do that tonight, but I've regretted this design decision since it was made. Thanks for taking the time to improve it.

@makslevental makslevental force-pushed the remove_mixins_3 branch 3 times, most recently from 2ad08c1 to d5a8f42 Compare October 12, 2023 14:54
Copy link
Member

@rkayaith rkayaith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, this seems a lot easier to understand/follow imo

return self.attributes["sym_visibility"]

@property
def name(self) -> StringAttr:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trust Stella and your judgement on the mixins ... But this also feels a bit fragile for me. We don't have a single source of truth now for these properties. Not even anything that could flag it.

Copy link
Contributor Author

@makslevental makslevental Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are indeed fragile I agree but they're also opt-in (to some extent): you can always import the <DIALECT>_ops_gen.py file directly and skip registering these (and we should be diligent about checking whether they're internally consumed/called).

Additionally we could defend against drift by making sure each such property is exercised in a test (I'm hoping most already are...).

But adding some of these to the generated builders should be straightforward - I'll take a shot at it soon.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this seems like an improvement. I'd love for a way to avoid things accidentally getting out of sync. I can't say I have a satisfying answer for that ... it seems it is mostly just the "constructor" that is modified and potentially some extra class members, but don't know if there is mechanism to make that easy while ensuring SOT is ODS still. Well and I'd want overwrites in python to be the way to do so, but have it be minimal, non-magical (e.g., the mixin file approach is a bit magical) but also avoid footgun if operands etc change.

@makslevental
Copy link
Contributor Author

it seems it is mostly just the "constructor" that is modified and potentially some extra class members, but don't know if there is mechanism to make that easy while ensuring SOT is ODS still.

One thing to consider, if this is a concern, is to drop registering the mixins as the default op.opview. Then they could break in an of themselves (which would be embarrassing) but that breakage would be localized a single instantiation.

@ftynse
Copy link
Member

ftynse commented Oct 19, 2023

This looks like an improvement over the previous state, at least from my better understanding of how this works. Maybe we can figure out a mechanism that checks whether the FooBarOp derived classes are having the required decorators to avoid the footgun.

As for handling changes to ODS op specification, this doesn't make it any worse. In the existing approach, we would have also needed to update the mixin classes when ODS changed (there have been several cases recently). I think the best way to guard against this is to keep improving the generated API (e.g., generate overloads) so that the amount of handwritten builders becomes minimal.

@makslevental makslevental merged commit a2288a8 into llvm:main Oct 19, 2023
@makslevental makslevental deleted the remove_mixins_3 branch October 19, 2023 21:20
@teqdruid teqdruid mentioned this pull request Oct 19, 2023
makslevental added a commit that referenced this pull request Oct 19, 2023
#68853 enabled a lot of nice
cleanup. Note, I made sure each of the touched extensions had tests.
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Oct 23, 2023
The way MLIR dialects are allowed to be extended in Python has recently
changed (in llvm/llvm-project#68853), so we have
to update our bindings.

PiperOrigin-RevId: 575796552
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Oct 23, 2023
The way MLIR dialects are allowed to be extended in Python has recently
changed (in llvm/llvm-project#68853), so we have
to update our bindings.

PiperOrigin-RevId: 575796552
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Oct 23, 2023
llvm/llvm-project#68853 changed the structure of
the upstream MLIR Python bindings, breaking the jaxlib build. Update our
build scripts to match.
@ingomueller-net
Copy link
Contributor

As for handling changes to ODS op specification, this doesn't make it any worse. In the existing approach, we would have also needed to update the mixin classes when ODS changed (there have been several cases recently). I think the best way to guard against this is to keep improving the generated API (e.g., generate overloads) so that the amount of handwritten builders becomes minimal.

I agree that removing/reducing the need to write mix-ins/extensions should remove/reduce the test surface and thus this issue. For the remainder, I thought about a way to test this some time ago: https://reviews.llvm.org/D159100. I can revive that if there is interest/consensus.

copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Oct 24, 2023
The way MLIR dialects are allowed to be extended in Python has recently
changed (in llvm/llvm-project#68853), so we have
to update our bindings.

PiperOrigin-RevId: 575796552
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Oct 24, 2023
The way MLIR dialects are allowed to be extended in Python has recently
changed (in llvm/llvm-project#68853), so we have
to update our bindings.

PiperOrigin-RevId: 575796552
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Oct 24, 2023
The way MLIR dialects are allowed to be extended in Python has recently
changed (in llvm/llvm-project#68853), so we have
to update our bindings.

PiperOrigin-RevId: 575796552
copybara-service bot pushed a commit to jax-ml/jax that referenced this pull request Oct 24, 2023
The way MLIR dialects are allowed to be extended in Python has recently
changed (in llvm/llvm-project#68853), so we have
to update our bindings.

PiperOrigin-RevId: 576060814
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants