Skip to content

Commit

Permalink
Revert "[mlir][python]Add sugared buider for transform.named_sequence (
Browse files Browse the repository at this point in the history
…llvm#71597)"

This reverts commit 4f51b2b.
  • Loading branch information
nicolasvasilache committed Nov 8, 2023
1 parent 4c9f7b6 commit be056f6
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 47 deletions.
28 changes: 0 additions & 28 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class NamedSequenceOp(NamedSequenceOp):
def __init__(
self,
sym_name,
input_types: Sequence[Type],
result_types: Sequence[Type],
):
function_type = FunctionType.get(input_types, result_types)
super().__init__(
sym_name=sym_name,
function_type=TypeAttr.get(function_type),
)
self.regions[0].blocks.append(*input_types)

@property
def body(self) -> Block:
return self.regions[0].blocks[0]

@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]

@property
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
def __init__(
Expand Down
118 changes: 99 additions & 19 deletions mlir/test/python/dialects/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def run(f):
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
f(module)
f()
print(module)
return f


@run
def testTypes(module: Module):
def testTypes():
# CHECK-LABEL: TEST: testTypes
# CHECK: !transform.any_op
any_op = transform.AnyOpType.get()
Expand Down Expand Up @@ -44,7 +44,7 @@ def testTypes(module: Module):


@run
def testSequenceOp(module: Module):
def testSequenceOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
Expand All @@ -60,23 +60,103 @@ def testSequenceOp(module: Module):


@run
def testNamedSequenceOp(module: Module):
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
named_sequence = transform.NamedSequenceOp(
'__transform_main',
[transform.AnyOpType.get()],
def testNestedSequenceOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
)
with InsertionPoint(nested.body):
doubly_nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
nested.bodyTarget,
)
with InsertionPoint(doubly_nested.body):
transform.YieldOp([doubly_nested.bodyTarget])
transform.YieldOp()
transform.YieldOp()
# CHECK-LABEL: TEST: testNestedSequenceOp
# CHECK: transform.sequence failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
# CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
# CHECK: yield %[[ARG2]] : !transform.any_op
# CHECK: }
# CHECK: }
# CHECK: }


@run
def testSequenceOpWithExtras():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
)
with InsertionPoint(sequence.body):
transform.YieldOp()
# CHECK-LABEL: TEST: testSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):


@run
def testNestedSequenceOpWithExtras():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
)
with InsertionPoint(sequence.body):
nested = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
sequence.bodyTarget,
sequence.bodyExtraArgs,
)
with InsertionPoint(nested.body):
transform.YieldOp()
transform.YieldOp()
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
# CHECK: transform.sequence failures(propagate)
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)


@run
def testTransformPDLOps():
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(withPdl.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
withPdl.bodyTarget,
)
with InsertionPoint(named_sequence.body):
transform.YieldOp([named_sequence.bodyTarget])
# CHECK-LABEL: TEST: testNamedSequenceOp
# CHECK: module attributes {transform.with_named_sequence} {
# CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
# CHECK: yield %[[ARG0]] : !transform.any_op
with InsertionPoint(sequence.body):
match = transform_pdl.PDLMatchOp(
transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
)
transform.YieldOp(match)
# CHECK-LABEL: TEST: testTransformPDLOps
# CHECK: transform.with_pdl_patterns {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
# CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
# CHECK: yield %[[RES]] : !transform.any_op
# CHECK: }
# CHECK: }


@run
def testGetParentOp(module: Module):
def testGetParentOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
Expand All @@ -95,7 +175,7 @@ def testGetParentOp(module: Module):


@run
def testMergeHandlesOp(module: Module):
def testMergeHandlesOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
Expand All @@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):


@run
def testApplyPatternsOpCompact(module: Module):
def testApplyPatternsOpCompact():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
Expand All @@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):


@run
def testApplyPatternsOpWithType(module: Module):
def testApplyPatternsOpWithType():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [],
transform.OperationType.get('test.dummy')
Expand All @@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):


@run
def testReplicateOp(module: Module):
def testReplicateOp():
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(
Expand Down

0 comments on commit be056f6

Please sign in to comment.