Skip to content

Commit

Permalink
[UMA] pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelJKlaiber committed Aug 2, 2022
1 parent 6428d35 commit 169f338
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
19 changes: 11 additions & 8 deletions python/tvm/relay/backend/contrib/uma/api/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,28 @@ def partition(
mod["main"] = bind_params_by_name(mod["main"], params)

pass_sequence = []
pass_sequence.extend([p[1] for p in self._relay_passes if p[0] == PassPhase.PRE_PARTITIONING])
pass_sequence.extend(
[p[1] for p in self._relay_passes if p[0] == PassPhase.PRE_PARTITIONING]
)
pass_sequence.append(relay.transform.MergeComposite(self._pattern_table()))
pass_sequence.append(relay.transform.AnnotateTarget(self.target_name))
if self.merge_compiler_regions:
pass_sequence.append(relay.transform.MergeCompilerRegions())
pass_sequence.append(relay.transform.PartitionGraph())
pass_sequence.extend([p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_0])

pass_sequence.extend(
[p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_0]
)

sequential_passes = tvm.transform.Sequential(pass_sequence)
mod = sequential_passes(mod)


# Defunctionalize the partitioned functions to allow lowering
for gvar, func in mod.functions.items():
mod.update_func(gvar, relay.transform.Defunctionalization(func, mod))

post_partition_passes_1 = tvm.transform.Sequential([p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_1])
mod = post_partition_passes_1(mod)


post_partition_passes_1 = tvm.transform.Sequential(
[p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_1]
)
mod = post_partition_passes_1(mod)

return mod
3 changes: 0 additions & 3 deletions tests/python/contrib/test_uma/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def test_existing_pattern_tables(workload, backend, merge):
partitioner.register()
partitioned_mod = partitioner.partition(mod)



def partition_default(mod):
"""partitions using default BYOC flow"""

Expand All @@ -85,7 +83,6 @@ def partition_default(mod):
if merge:
sequence.append(relay.transform.MergeCompilerRegions())


sequence.append(relay.transform.PartitionGraph())
sequential = tvm.transform.Sequential(sequence)

Expand Down

0 comments on commit 169f338

Please sign in to comment.