diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 1e532367a321..cf21ea950549 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -3,7 +3,7 @@ import ast import operator import sys -from .util import make_nop, halide_imm_types +from .util import make_nop, halide_imm_types, is_docstring from .intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis @@ -15,7 +15,7 @@ def list_to_block(visit, lst): """Convert a list of Python IR nodes to HalideIR Block""" - lst = list(map(visit, lst)) + lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] if not lst: return make_nop() diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 43d26e859560..2a43957e9706 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -1,5 +1,6 @@ """Internal utilities for parsing Python subset to HalideIR""" +import ast import inspect import numpy from .intrin import HYBRID_GLOBALS @@ -22,6 +23,11 @@ def make_nop(): return _make.Evaluate(_api.const(0, dtype='int32')) +def is_docstring(node): + """Checks if a Python AST node is a docstring""" + return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str) + + def _pruned_source(func): """Prune source code's extra leading spaces""" lines = inspect.getsource(func).split('\n') diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 0f500d7c704f..ef0bcf8f72e5 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -43,6 +43,7 @@ def tvm_val_2_py_val(val): @script def outer_product(n, m, a, b, c): + """This is a simple outer product""" for i in range(n): for j in range(m): c[i, j] = a[i] * b[j]