Skip to content

Commit

Permalink
[HybridScript] Capture constant external python variables (apache#3157)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and wweic committed May 13, 2019
1 parent 21f6564 commit 3982ef4
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 21 deletions.
6 changes: 5 additions & 1 deletion python/tvm/hybrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from __future__ import absolute_import as _abs

import inspect

from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body
Expand All @@ -55,7 +57,9 @@ def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
return source_to_op(src, func.__globals__, args)
closure_vars = inspect.getclosurevars(func).nonlocals
closure_vars.update(inspect.getclosurevars(func).globals)
return source_to_op(src, args, func.__globals__, closure_vars)

from .runtime import _enter_hybrid_runtime, _restore_runtime
intersect = _enter_hybrid_runtime(func)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/hybrid/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, src=None, name=None):

def __call__(self, *args):
if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args)
return source_to_op(self.root_, args, globals(), {})
return self.func_(*args)


Expand Down
50 changes: 35 additions & 15 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from enum import Enum

from .util import _internal_assert
from .util import _internal_assert, _apply_indices
from . import calls
from . import util
from .preprocessor import determine_variable_usage
Expand Down Expand Up @@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor):
}


def __init__(self, args, usage, symbols, func_name=None):
def __init__(self, args, usage, symbols, closure_vars, func_name=None):
"""
Parameters
----------
Expand All @@ -122,6 +122,12 @@ def __init__(self, args, usage, symbols, func_name=None):
usage: A dict of variables used in last in this function
Provided by last lower pass, which collects this information
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
func_name: str
Expand All @@ -136,6 +142,8 @@ def __init__(self, args, usage, symbols, func_name=None):
if isinstance(v, types.FunctionType):
self.add_symbol(k, Symbol.Callable, v)

self.closure_vars = closure_vars

self.binds = {} # Thread binds
self.device = 0 # Is it generating device

Expand Down Expand Up @@ -236,7 +244,11 @@ def visit_Expr(self, node):
def visit_Name(self, node):
name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(eval(name)) #pylint: disable=eval-used
return _api.convert(ast.literal_eval(name))

if name in self.closure_vars:
return _api.convert(self.closure_vars[name])

ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
Expand Down Expand Up @@ -356,10 +368,12 @@ def visit_Attribute(self, node):
buf = self.visit(node.value)
return getattr(buf, node.attr)


def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
if node.value.id in self.closure_vars:
args = ast.literal_eval(str(args))
return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))

buf = self.visit(node.value)
if isinstance(buf, Array):
Expand Down Expand Up @@ -576,7 +590,7 @@ def visit_Assert(self, node):
return _make.AssertStmt(test, mesg, util.make_nop())


def parse_python(src, symbols, args):
def parse_python(src, args, symbols, closure_vars):
"""The helper function of calling the AST visitor
Parameters
Expand All @@ -585,29 +599,32 @@ def parse_python(src, symbols, args):
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
root : Stmt
The result Halide IR and the parser class instance.
"""
root = ast.parse(src) if isinstance(src, str) else src
_internal_assert(root, ast.AST)
var_usage = determine_variable_usage(root, args, symbols)
parser = HybridParser(args, var_usage, symbols)
var_usage = determine_variable_usage(root, args, symbols, closure_vars)
parser = HybridParser(args, var_usage, symbols, closure_vars)
parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!')
return parser


def source_to_op(src, symbols, args):
def source_to_op(src, args, symbols, closure_vars):
"""Another level of wrapper
Parameters
Expand All @@ -616,20 +633,23 @@ def source_to_op(src, symbols, args):
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
res : list of output tensors
The result of output tensors of the formed OpNode.
"""
parser = parse_python(src, symbols, args)
parser = parse_python(src, args, symbols, closure_vars)

input_tensors = []
for i in args:
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/hybrid/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def __init__(self, args, symbols):
def __init__(self, args, symbols, closure_vars):
self.status = {}
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
self.symbols = symbols

self.closure_vars = closure_vars

def visit_FunctionDef(self, node):
self.scope_level.append(node)
Expand Down Expand Up @@ -89,6 +89,14 @@ def visit_Name(self, node):
"Iter var cannot be overwritten")

if node.id not in self.status.keys():
# It is a captured value in closure
if node.id in self.closure_vars:
try:
ast.literal_eval(str(self.closure_vars[node.id]))
except ValueError:
raise ValueError("Only support capturing constant values in closure")
return

_internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id)
if self.aug_assign_:
Expand All @@ -102,8 +110,8 @@ def visit_Name(self, node):
self.status[node.id] = (decl, loop, usage)


def determine_variable_usage(root, args, symbols):
def determine_variable_usage(root, args, symbols, closure_vars):
"""The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args, symbols)
visitor = PyVariableUsage(args, symbols, closure_vars)
visitor.visit(root)
return visitor.status
6 changes: 6 additions & 0 deletions python/tvm/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,9 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False

def _apply_indices(value, indices):
"""Apply multidimensional index"""
if indices:
return _apply_indices(value[indices[0]], indices[1:])
return value
19 changes: 19 additions & 0 deletions tests/python/unittest/test_hybrid_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,24 @@ def outer_product(a, b):

# Test loop binds

def test_capture():
n = 8

constant_tuple = (10, n)
constant_list = [[1, 2], [3, n]]
const_value = 1

@tvm.hybrid.script
def add_something(a):
c = output_tensor((constant_tuple[1],), 'int32')
for i in range(constant_tuple[1]):
c[i] = a[i] + constant_list[1][const_value]
return c

a = tvm.placeholder((n, ), dtype='int32', name='a')

func, ins, outs = run_and_check(add_something, [a])
run_and_check(func, ins, outs=outs)

if __name__ == "__main__":
test_outer_product()
Expand All @@ -786,5 +804,6 @@ def outer_product(a, b):
test_bool()
test_const_range()
test_schedule()
test_capture()
# TODO:
# test_inplace()

0 comments on commit 3982ef4

Please sign in to comment.