Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][python]Add sugared buider for transform.named_sequence #71597

Merged
merged 1 commit into from
Nov 8, 2023

Conversation

nicolasvasilache
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 7, 2023

@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/71597.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/transform/init.py (+28)
  • (modified) mlir/test/python/dialects/transform.py (+19-99)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 166c5c5ca4ec344..23b278d374332b5 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class NamedSequenceOp(NamedSequenceOp):
+    def __init__(
+        self,
+        sym_name,
+        input_types: Sequence[Type],
+        result_types: Sequence[Type],
+    ):
+        function_type = FunctionType.get(input_types, result_types)
+        super().__init__(
+            sym_name=sym_name,
+            function_type=TypeAttr.get(function_type),
+        )
+        self.regions[0].blocks.append(*input_types)
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
+
+    @property
+    def bodyExtraArgs(self) -> BlockArgumentList:
+        return self.body.arguments[1:]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index d778172a607a360..e7f448850a66aa1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             print("\nTEST:", f.__name__)
-            f()
+            f(module)
         print(module)
     return f
 
 
 @run
-def testTypes():
+def testTypes(module: Module):
     # CHECK-LABEL: TEST: testTypes
     # CHECK: !transform.any_op
     any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
 
 
 @run
-def testSequenceOp():
+def testSequenceOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
@@ -60,103 +60,23 @@ def testSequenceOp():
 
 
 @run
-def testNestedSequenceOp():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        nested = transform.SequenceOp(
-            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
-        )
-        with InsertionPoint(nested.body):
-            doubly_nested = transform.SequenceOp(
-                transform.FailurePropagationMode.Propagate,
-                [transform.AnyOpType.get()],
-                nested.bodyTarget,
-            )
-            with InsertionPoint(doubly_nested.body):
-                transform.YieldOp([doubly_nested.bodyTarget])
-            transform.YieldOp()
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: testNestedSequenceOp
-    # CHECK: transform.sequence failures(propagate) {
-    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
-    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
-    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
-    # CHECK:       yield %[[ARG2]] : !transform.any_op
-    # CHECK:     }
-    # CHECK:   }
-    # CHECK: }
-
-
-@run
-def testSequenceOpWithExtras():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.AnyOpType.get(),
-        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
-    )
-    with InsertionPoint(sequence.body):
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: testSequenceOpWithExtras
-    # CHECK: transform.sequence failures(propagate)
-    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
-
-
-@run
-def testNestedSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.AnyOpType.get(),
-        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
-    )
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(
-            transform.FailurePropagationMode.Propagate,
-            [],
-            sequence.bodyTarget,
-            sequence.bodyExtraArgs,
-        )
-    with InsertionPoint(nested.body):
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
-  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
-
-
-@run
-def testTransformPDLOps():
-  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(withPdl.body):
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
+def testNamedSequenceOp(module: Module):
+    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+    named_sequence = transform.NamedSequenceOp(
+        '__transform_main',
+        [transform.AnyOpType.get()],
         [transform.AnyOpType.get()],
-        withPdl.bodyTarget,
     )
-    with InsertionPoint(sequence.body):
-      match = transform_pdl.PDLMatchOp(
-          transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
-      )
-      transform.YieldOp(match)
-  # CHECK-LABEL: TEST: testTransformPDLOps
-  # CHECK: transform.with_pdl_patterns {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
-  # CHECK:     yield %[[RES]] : !transform.any_op
-  # CHECK:   }
-  # CHECK: }
+    with InsertionPoint(named_sequence.body):
+        transform.YieldOp([named_sequence.bodyTarget])
+    # CHECK-LABEL: TEST: testNamedSequenceOp
+    # CHECK: module attributes {transform.with_named_sequence} {
+    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+    # CHECK:   yield %[[ARG0]] : !transform.any_op
 
 
 @run
-def testGetParentOp():
+def testGetParentOp(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -175,7 +95,7 @@ def testGetParentOp():
 
 
 @run
-def testMergeHandlesOp():
+def testMergeHandlesOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -189,7 +109,7 @@ def testMergeHandlesOp():
 
 
 @run
-def testApplyPatternsOpCompact():
+def testApplyPatternsOpCompact(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -204,7 +124,7 @@ def testApplyPatternsOpCompact():
 
 
 @run
-def testApplyPatternsOpWithType():
+def testApplyPatternsOpWithType(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
@@ -220,7 +140,7 @@ def testApplyPatternsOpWithType():
 
 
 @run
-def testReplicateOp():
+def testReplicateOp(module: Module):
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (TODO for myself: figure out the anon attribute thing).

mlir/test/python/dialects/transform.py Show resolved Hide resolved
Copy link

github-actions bot commented Nov 7, 2023

⚠️ Python code formatter, darker found issues in your code. ⚠️

You can test this locally with the following command:
darker --check --diff -r 25ec1fa969a0d13f440222f575277f9601eaea76..0e382dd88c7ae6d07da35f9a7312dbfd9850b630 mlir/python/mlir/dialects/transform/__init__.py mlir/test/python/dialects/transform.py
View the diff from darker here.
--- test/python/dialects/transform.py	2023-11-07 22:54:30.000000 +0000
+++ test/python/dialects/transform.py	2023-11-07 23:05:23.716448 +0000
@@ -61,11 +61,11 @@
 
 @run
 def testNamedSequenceOp(module: Module):
     module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
     named_sequence = transform.NamedSequenceOp(
-        '__transform_main',
+        "__transform_main",
         [transform.AnyOpType.get()],
         [transform.AnyOpType.get()],
     )
     with InsertionPoint(named_sequence.body):
         transform.YieldOp([named_sequence.bodyTarget])

@nicolasvasilache nicolasvasilache merged commit 4f51b2b into llvm:main Nov 8, 2023
4 of 5 checks passed
nicolasvasilache added a commit that referenced this pull request Nov 8, 2023
Comment on lines -63 to -155
def testNestedSequenceOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
)
with InsertionPoint(nested.body):
doubly_nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
nested.bodyTarget,
)
with InsertionPoint(doubly_nested.body):
transform.YieldOp([doubly_nested.bodyTarget])
transform.YieldOp()
transform.YieldOp()
# CHECK-LABEL: TEST: testNestedSequenceOp
# CHECK: transform.sequence failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
# CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
# CHECK: yield %[[ARG2]] : !transform.any_op
# CHECK: }
# CHECK: }
# CHECK: }


@run
def testSequenceOpWithExtras():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
)
with InsertionPoint(sequence.body):
transform.YieldOp()
# CHECK-LABEL: TEST: testSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):


@run
def testNestedSequenceOpWithExtras():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
)
with InsertionPoint(sequence.body):
nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
sequence.bodyTarget,
sequence.bodyExtraArgs,
)
with InsertionPoint(nested.body):
transform.YieldOp()
transform.YieldOp()
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)


@run
def testTransformPDLOps():
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(withPdl.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
def testNamedSequenceOp(module: Module):
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
named_sequence = transform.NamedSequenceOp(
'__transform_main',
[transform.AnyOpType.get()],
[transform.AnyOpType.get()],
withPdl.bodyTarget,
)
with InsertionPoint(sequence.body):
match = transform_pdl.PDLMatchOp(
transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
)
transform.YieldOp(match)
# CHECK-LABEL: TEST: testTransformPDLOps
# CHECK: transform.with_pdl_patterns {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
# CHECK: yield %[[RES]] : !transform.any_op
# CHECK: }
# CHECK: }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are all these deleted? These are tests for Python API of ops that still exist, and these ops are not in any way affected by this change. Please revert deletion!

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have requested changes here had I been given the chance. This shouldn't just decrease test coverage without justification.

@nicolasvasilache
Copy link
Contributor Author

I would have requested changes here had I been given the chance. This shouldn't just decrease test coverage without justification.

sorry, bad debug state landed, on it

nicolasvasilache added a commit to nicolasvasilache/llvm-project that referenced this pull request Nov 8, 2023
nicolasvasilache added a commit to nicolasvasilache/llvm-project that referenced this pull request Nov 8, 2023
nicolasvasilache added a commit that referenced this pull request Nov 8, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants