-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][python]Add sugared buider for transform.named_sequence #71597
Conversation
@llvm/pr-subscribers-mlir Author: Nicolas Vasilache (nicolasvasilache) ChangesFull diff: https://github.com/llvm/llvm-project/pull/71597.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 166c5c5ca4ec344..23b278d374332b5 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
+@_ods_cext.register_operation(_Dialect, replace=True)
+class NamedSequenceOp(NamedSequenceOp):
+ def __init__(
+ self,
+ sym_name,
+ input_types: Sequence[Type],
+ result_types: Sequence[Type],
+ ):
+ function_type = FunctionType.get(input_types, result_types)
+ super().__init__(
+ sym_name=sym_name,
+ function_type=TypeAttr.get(function_type),
+ )
+ self.regions[0].blocks.append(*input_types)
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
+
+ @property
+ def bodyExtraArgs(self) -> BlockArgumentList:
+ return self.body.arguments[1:]
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index d778172a607a360..e7f448850a66aa1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
- f()
+ f(module)
print(module)
return f
@run
-def testTypes():
+def testTypes(module: Module):
# CHECK-LABEL: TEST: testTypes
# CHECK: !transform.any_op
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
@run
-def testSequenceOp():
+def testSequenceOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
@@ -60,103 +60,23 @@ def testSequenceOp():
@run
-def testNestedSequenceOp():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
- )
- with InsertionPoint(nested.body):
- doubly_nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [transform.AnyOpType.get()],
- nested.bodyTarget,
- )
- with InsertionPoint(doubly_nested.body):
- transform.YieldOp([doubly_nested.bodyTarget])
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOp
- # CHECK: transform.sequence failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
- # CHECK: yield %[[ARG2]] : !transform.any_op
- # CHECK: }
- # CHECK: }
- # CHECK: }
-
-
-@run
-def testSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.AnyOpType.get(),
- [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
- )
- with InsertionPoint(sequence.body):
- transform.YieldOp()
- # CHECK-LABEL: TEST: testSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
-
-
-@run
-def testNestedSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.AnyOpType.get(),
- [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
- )
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- sequence.bodyTarget,
- sequence.bodyExtraArgs,
- )
- with InsertionPoint(nested.body):
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
- # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
-
-
-@run
-def testTransformPDLOps():
- withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
- with InsertionPoint(withPdl.body):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
+def testNamedSequenceOp(module: Module):
+ module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+ named_sequence = transform.NamedSequenceOp(
+ '__transform_main',
+ [transform.AnyOpType.get()],
[transform.AnyOpType.get()],
- withPdl.bodyTarget,
)
- with InsertionPoint(sequence.body):
- match = transform_pdl.PDLMatchOp(
- transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
- )
- transform.YieldOp(match)
- # CHECK-LABEL: TEST: testTransformPDLOps
- # CHECK: transform.with_pdl_patterns {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
- # CHECK: yield %[[RES]] : !transform.any_op
- # CHECK: }
- # CHECK: }
+ with InsertionPoint(named_sequence.body):
+ transform.YieldOp([named_sequence.bodyTarget])
+ # CHECK-LABEL: TEST: testNamedSequenceOp
+ # CHECK: module attributes {transform.with_named_sequence} {
+ # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+ # CHECK: yield %[[ARG0]] : !transform.any_op
@run
-def testGetParentOp():
+def testGetParentOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -175,7 +95,7 @@ def testGetParentOp():
@run
-def testMergeHandlesOp():
+def testMergeHandlesOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -189,7 +109,7 @@ def testMergeHandlesOp():
@run
-def testApplyPatternsOpCompact():
+def testApplyPatternsOpCompact(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -204,7 +124,7 @@ def testApplyPatternsOpCompact():
@run
-def testApplyPatternsOpWithType():
+def testApplyPatternsOpWithType(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [],
transform.OperationType.get('test.dummy')
@@ -220,7 +140,7 @@ def testApplyPatternsOpWithType():
@run
-def testReplicateOp():
+def testReplicateOp(module: Module):
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (TODO for myself: figure out the anon attribute thing).
You can test this locally with the following command:darker --check --diff -r 25ec1fa969a0d13f440222f575277f9601eaea76..0e382dd88c7ae6d07da35f9a7312dbfd9850b630 mlir/python/mlir/dialects/transform/__init__.py mlir/test/python/dialects/transform.py View the diff from darker here.--- test/python/dialects/transform.py 2023-11-07 22:54:30.000000 +0000
+++ test/python/dialects/transform.py 2023-11-07 23:05:23.716448 +0000
@@ -61,11 +61,11 @@
@run
def testNamedSequenceOp(module: Module):
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
named_sequence = transform.NamedSequenceOp(
- '__transform_main',
+ "__transform_main",
[transform.AnyOpType.get()],
[transform.AnyOpType.get()],
)
with InsertionPoint(named_sequence.body):
transform.YieldOp([named_sequence.bodyTarget])
|
def testNestedSequenceOp(): | ||
sequence = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() | ||
) | ||
with InsertionPoint(sequence.body): | ||
nested = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget | ||
) | ||
with InsertionPoint(nested.body): | ||
doubly_nested = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, | ||
[transform.AnyOpType.get()], | ||
nested.bodyTarget, | ||
) | ||
with InsertionPoint(doubly_nested.body): | ||
transform.YieldOp([doubly_nested.bodyTarget]) | ||
transform.YieldOp() | ||
transform.YieldOp() | ||
# CHECK-LABEL: TEST: testNestedSequenceOp | ||
# CHECK: transform.sequence failures(propagate) { | ||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): | ||
# CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) { | ||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): | ||
# CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) { | ||
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op): | ||
# CHECK: yield %[[ARG2]] : !transform.any_op | ||
# CHECK: } | ||
# CHECK: } | ||
# CHECK: } | ||
|
||
|
||
@run | ||
def testSequenceOpWithExtras(): | ||
sequence = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, | ||
[], | ||
transform.AnyOpType.get(), | ||
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], | ||
) | ||
with InsertionPoint(sequence.body): | ||
transform.YieldOp() | ||
# CHECK-LABEL: TEST: testSequenceOpWithExtras | ||
# CHECK: transform.sequence failures(propagate) | ||
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): | ||
|
||
|
||
@run | ||
def testNestedSequenceOpWithExtras(): | ||
sequence = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, | ||
[], | ||
transform.AnyOpType.get(), | ||
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], | ||
) | ||
with InsertionPoint(sequence.body): | ||
nested = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, | ||
[], | ||
sequence.bodyTarget, | ||
sequence.bodyExtraArgs, | ||
) | ||
with InsertionPoint(nested.body): | ||
transform.YieldOp() | ||
transform.YieldOp() | ||
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras | ||
# CHECK: transform.sequence failures(propagate) | ||
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): | ||
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) | ||
|
||
|
||
@run | ||
def testTransformPDLOps(): | ||
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) | ||
with InsertionPoint(withPdl.body): | ||
sequence = transform.SequenceOp( | ||
transform.FailurePropagationMode.Propagate, | ||
def testNamedSequenceOp(module: Module): | ||
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get() | ||
named_sequence = transform.NamedSequenceOp( | ||
'__transform_main', | ||
[transform.AnyOpType.get()], | ||
[transform.AnyOpType.get()], | ||
withPdl.bodyTarget, | ||
) | ||
with InsertionPoint(sequence.body): | ||
match = transform_pdl.PDLMatchOp( | ||
transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" | ||
) | ||
transform.YieldOp(match) | ||
# CHECK-LABEL: TEST: testTransformPDLOps | ||
# CHECK: transform.with_pdl_patterns { | ||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): | ||
# CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { | ||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): | ||
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] | ||
# CHECK: yield %[[RES]] : !transform.any_op | ||
# CHECK: } | ||
# CHECK: } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are all these deleted? These are tests for Python API of ops that still exist, and these ops are not in any way affected by this change. Please revert deletion!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would have requested changes here had I been given the chance. This shouldn't just decrease test coverage without justification.
sorry, bad debug state landed, on it |
…llvm#71597)" This reverts commit 4f51b2b.
Address issues with #71597 post-revert and and reland
No description provided.