Skip to content

Commit

Permalink
[TIR][Schedule][UX] Beautify TIR Trace Printing (#12507)
Browse files Browse the repository at this point in the history
Following #12197, this PR introduces
`Schedule.show()` which convenience the user experience in the following
two aspects:
- Python syntax highlighting
- Outputs a schedule function instead of standalone instructions so that
it's easier to follow.

To demonstrate this change:
- Before `Schedule.show()` is introduced:
<img width="555" alt="image" src="https://user-images.githubusercontent.com/22515877/185713487-03722566-1df7-45c7-a034-c1460d399681.png">

- After this change:
<img width="583" alt="image" src="https://user-images.githubusercontent.com/22515877/185713564-c54f3a9d-cd52-4709-a8b8-d8a61361e611.png">
  • Loading branch information
junrushao authored Aug 20, 2022
1 parent eb31123 commit 3b3443b
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 56 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/space_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]):
for space in spaces:
trace = Trace(space.trace.insts, {})
trace = trace.simplified(remove_postproc=True)
str_trace = "\n".join(str(trace).strip().splitlines())
str_trace = "\n".join(t[2:] for t in str(trace).strip().splitlines()[2:] if t != " pass")
actual_traces.add(str_trace)
assert str_trace in expected_traces, "\n" + str_trace
assert len(expected_traces) == len(actual_traces)
Expand Down
28 changes: 16 additions & 12 deletions python/tvm/script/highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
"""Highlight printed TVM script.
"""

from typing import Union, Optional
import warnings
import sys
import warnings
from typing import Optional, Union

from tvm.ir import IRModule
from tvm.tir import PrimFunc


def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> None:
def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = None) -> None:
"""
Print highlighted TVM script string with Pygments
Parameters
----------
printable : Union[IRModule, PrimFunc]
printable : Union[IRModule, PrimFunc, str]
The TVM script to be printed
style : str, optional
Printing style, auto-detected if None.
Expand All @@ -44,16 +44,17 @@ def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) ->
installing the Pygment library. Other Pygment styles can be found in
https://pygments.org/styles/
"""

if isinstance(printable, (IRModule, PrimFunc)):
printable = printable.script()
try:
# pylint: disable=import-outside-toplevel
import pygments
from packaging import version
from pygments import highlight
from pygments.formatters import HtmlFormatter, Terminal256Formatter
from pygments.lexers.python import Python3Lexer
from pygments.formatters import Terminal256Formatter, HtmlFormatter
from pygments.style import Style
from pygments.token import Keyword, Name, Comment, String, Number, Operator
from packaging import version
from pygments.token import Comment, Keyword, Name, Number, Operator, String

if version.parse(pygments.__version__) < version.parse("2.4.0"):
raise ImportError("Required Pygments version >= 2.4.0 but got " + pygments.__version__)
Expand All @@ -68,7 +69,7 @@ def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) ->
+ install_cmd,
category=UserWarning,
)
print(printable.script())
print(printable)
else:

class JupyterLight(Style):
Expand Down Expand Up @@ -136,11 +137,14 @@ class AnsiTerminalDefault(Style):
style = AnsiTerminalDefault

if is_in_notebook: # print with HTML display
from IPython.display import display, HTML # pylint: disable=import-outside-toplevel
from IPython.display import ( # pylint: disable=import-outside-toplevel
HTML,
display,
)

formatter = HtmlFormatter(style=JupyterLight)
formatter.noclasses = True # inline styles
html = highlight(printable.script(), Python3Lexer(), formatter)
html = highlight(printable, Python3Lexer(), formatter)
display(HTML(html))
else:
print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style)))
print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style)))
14 changes: 14 additions & 0 deletions python/tvm/tir/schedule/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,17 @@ def apply_json_to_schedule(json_obj: JSON_TYPE, sch: "Schedule") -> None:
The TensorIR schedule
"""
_ffi_api.TraceApplyJSONToSchedule(json_obj, sch) # type: ignore # pylint: disable=no-member

def show(self, style: Optional[str] = None) -> None:
"""A sugar for print highlighted trace.
Parameters
----------
style : str, optional
Pygments styles extended by "light" (default) and "dark", by default "light"
"""
from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel
cprint,
)

cprint(str(self), style=style)
10 changes: 8 additions & 2 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -476,16 +476,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TraceNode>([](const ObjectRef& obj, ReprPrinter* p) {
const auto* self = obj.as<TraceNode>();
ICHECK_NOTNULL(self);
p->stream << "# from tvm import tir\n";
p->stream << "def apply_trace(sch: tir.Schedule) -> None:\n";
Array<String> repr = self->AsPython(/*remove_postproc=*/false);
bool is_first = true;
for (const String& line : repr) {
if (is_first) {
is_first = false;
} else {
p->stream << std::endl;
p->stream << '\n';
}
p->stream << line;
p->stream << " " << line;
}
if (is_first) {
p->stream << " pass";
}
p->stream << std::flush;
});

/**************** Instruction Registration ****************/
Expand Down
28 changes: 16 additions & 12 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,22 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
def correct_trace(a, b, c, d):
return "\n".join(
[
'b0 = sch.get_block(name="A", func_name="main")',
'b1 = sch.get_block(name="B", func_name="main")',
'b2 = sch.get_block(name="C", func_name="main")',
"sch.compute_inline(block=b1)",
"l3, l4 = sch.get_loops(block=b2)",
"l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)",
"l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)",
"sch.reorder(l5, l7, l6, l8)",
"l9, l10 = sch.get_loops(block=b0)",
"l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)",
"l13, l14 = sch.split(loop=l10, factors=" + str(d) + ", preserve_unit_iters=True)",
"sch.reorder(l11, l13, l12, l14)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="A", func_name="main")',
' b1 = sch.get_block(name="B", func_name="main")',
' b2 = sch.get_block(name="C", func_name="main")',
" sch.compute_inline(block=b1)",
" l3, l4 = sch.get_loops(block=b2)",
" l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)",
" l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)",
" sch.reorder(l5, l7, l6, l8)",
" l9, l10 = sch.get_loops(block=b0)",
" l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)",
" l13, l14 = sch.split(loop=l10, factors="
+ str(d)
+ ", preserve_unit_iters=True)",
" sch.reorder(l11, l13, l12, l14)",
]
)

Expand Down
92 changes: 63 additions & 29 deletions tests/python/unittest/test_tir_schedule_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ def test_trace_construct_1():
trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="block", func_name="main")',
"l1, l2 = sch.get_loops(block=b0)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="block", func_name="main")',
" l1, l2 = sch.get_loops(block=b0)",
)
)
assert len(trace.insts) == 2
Expand All @@ -182,9 +184,11 @@ def test_trace_construct_append_1():
trace.append(inst=_make_get_block("block2", BlockRV()))
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="block", func_name="main")',
"l1, l2 = sch.get_loops(block=b0)",
'b3 = sch.get_block(name="block2", func_name="main")',
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="block", func_name="main")',
" l1, l2 = sch.get_loops(block=b0)",
' b3 = sch.get_block(name="block2", func_name="main")',
)
)

Expand All @@ -193,14 +197,32 @@ def test_trace_construct_pop_1():
trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
last_inst = trace.insts[-1]
assert trace.pop().same_as(last_inst)
assert str(trace) == 'b0 = sch.get_block(name="block", func_name="main")'
assert str(trace) == "\n".join(
(
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="block", func_name="main")',
)
)


def test_trace_construct_pop_2():
trace = Trace([], {})
assert str(trace) == ""
assert str(trace) == "\n".join(
(
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
" pass",
)
)
assert trace.pop() is None
assert str(trace) == ""
assert str(trace) == "\n".join(
(
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
" pass",
)
)


def test_trace_apply_to_schedule():
Expand All @@ -226,18 +248,22 @@ def test_trace_simplified_1():
trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True)
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="B", func_name="main")',
"sch.compute_inline(block=b0)",
'b1 = sch.get_block(name="C", func_name="main")',
"sch.enter_postproc()",
"sch.compute_inline(block=b1)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="B", func_name="main")',
" sch.compute_inline(block=b0)",
' b1 = sch.get_block(name="C", func_name="main")',
" sch.enter_postproc()",
" sch.compute_inline(block=b1)",
)
)
trace = trace.simplified(remove_postproc=True)
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="B", func_name="main")',
"sch.compute_inline(block=b0)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="B", func_name="main")',
" sch.compute_inline(block=b0)",
)
)

Expand All @@ -246,21 +272,26 @@ def test_trace_simplified_2():
trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True)
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="B", func_name="main")',
"sch.compute_inline(block=b0)",
'b1 = sch.get_block(name="C", func_name="main")',
"sch.enter_postproc()",
"sch.compute_inline(block=b1)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="B", func_name="main")',
" sch.compute_inline(block=b0)",
' b1 = sch.get_block(name="C", func_name="main")',
" sch.enter_postproc()",
" sch.compute_inline(block=b1)",
)
)
trace = trace.simplified(remove_postproc=False)
print(trace.show())
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="B", func_name="main")',
"sch.compute_inline(block=b0)",
'b1 = sch.get_block(name="C", func_name="main")',
"sch.enter_postproc()",
"sch.compute_inline(block=b1)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="B", func_name="main")',
" sch.compute_inline(block=b0)",
' b1 = sch.get_block(name="C", func_name="main")',
" sch.enter_postproc()",
" sch.compute_inline(block=b1)",
)
)

Expand All @@ -269,9 +300,11 @@ def test_trace_simplified_3():
trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False)
assert str(trace) == "\n".join(
(
'b0 = sch.get_block(name="B", func_name="main")',
"l1, = sch.get_loops(block=b0)",
"l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
"# from tvm import tir",
"def apply_trace(sch: tir.Schedule) -> None:",
' b0 = sch.get_block(name="B", func_name="main")',
" l1, = sch.get_loops(block=b0)",
" l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
)
)

Expand Down Expand Up @@ -335,4 +368,5 @@ def test_apply_annotation_from_json():


if __name__ == "__main__":
tvm.testing.main()
test_trace_simplified_2()
# tvm.testing.main()

0 comments on commit 3b3443b

Please sign in to comment.