Skip to content

Commit

Permalink
[PASS/SETUP] Fix minior issues (apache#663)
Browse files Browse the repository at this point in the history
* [PASS/SETUP] Fix minior issues

* fix lint
  • Loading branch information
tqchen authored Nov 21, 2017
1 parent 46e6cae commit 9c0da90
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 23 deletions.
20 changes: 14 additions & 6 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,21 @@
namespace tvm {
namespace ir {

inline Expr Simplify(Expr a) {
return Halide::Internal::simplify(a);
}
/*!
* \brief Simplify the expression.
* \param expr The expression to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());

inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}
/*!
* \brief Simplify the statement.
* \param stmt The statement to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Simplify by applying canonical form.
Expand Down
42 changes: 28 additions & 14 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@
from setuptools import setup
from setuptools.extension import Extension

# We can not import `libinfo.py` in setup.py directly since __init__.py
# Will be invoked which introduces dependences
CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
def get_lib_path():
"""Get library path, name and version"""
# We can not import `libinfo.py` in setup.py directly since __init__.py
# Will be invoked which introduces dependences
CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
lib_path = libinfo['find_lib_path']()
version = libinfo['__version__']
libs = [lib_path[0]]
if libs[0].find("runtime") == -1:
for name in lib_path[1:]:
if name.find("runtime") != -1:
libs.append(name)
break
return libs, version

LIB_PATH = libinfo['find_lib_path']()
_, LIB_NAME = os.path.split(LIB_PATH[0])
__version__ = libinfo['__version__']
LIB_LIST, __version__ = get_lib_path()

def config_cython():
"""Try to configure cython and return cython configuration"""
Expand Down Expand Up @@ -81,18 +90,21 @@ def is_pure(self):

# For bdist_wheel only
if "bdist_wheel" in sys.argv:
shutil.copy(LIB_PATH[0], os.path.join(CURRENT_DIR, 'tvm'))
with open("MANIFEST.in", "w") as fo:
fo.write("include tvm/%s\n" % LIB_NAME)
for path in LIB_LIST:
shutil.copy(path, os.path.join(CURRENT_DIR, 'tvm'))
_, libname = os.path.split(path)
fo.write("include tvm/%s\n" % libname)
setup_kwargs = {
"include_package_data": True
}
else:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
rpath = os.path.relpath(LIB_PATH[0], curr_path)
for i, path in enumerate(LIB_LIST):
LIB_LIST[i] = os.path.relpath(path, curr_path)
setup_kwargs = {
"include_package_data": True,
"data_files": [('tvm', [rpath])]
"data_files": [('tvm', LIB_LIST)]
}

setup(name='tvm',
Expand All @@ -112,4 +124,6 @@ def is_pure(self):
# Wheel cleanup
if "bdist_wheel" in sys.argv:
os.remove("MANIFEST.in")
os.remove("tvm/%s" % LIB_NAME)
for path in LIB_LIST:
_, libname = os.path.split(path)
os.remove("tvm/%s" % LIB_NAME)
3 changes: 2 additions & 1 deletion python/tvm/_ffi/libinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def find_lib_path(name=None, search_path=None):
if not use_runtime:
# try to find lib_dll_path
lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
if use_runtime or not lib_found:
lib_found += [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
else:
# try to find runtime_dll_path
use_runtime = True
lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
Expand Down
12 changes: 10 additions & 2 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@ namespace ir {
TVM_REGISTER_API("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = Simplify(args[0].operator Stmt());
if (args.size() > 1) {
*ret = Simplify(args[0].operator Stmt(), args[1]);
} else {
*ret = Simplify(args[0].operator Stmt());
}
} else {
*ret = Simplify(args[0].operator Expr());
if (args.size() > 1) {
*ret = Simplify(args[0].operator Expr(), args[1]);
} else {
*ret = Simplify(args[0].operator Expr());
}
}
});

Expand Down
24 changes: 24 additions & 0 deletions src/arithmetic/canonical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/arithmetic.h>
#include "./canonical.h"
#include "./compute_expr.h"
#include "arithmetic/Simplify.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -559,5 +560,28 @@ Stmt CanonicalSimplify(Stmt stmt) {
Expr CanonicalSimplify(Expr expr) {
return arith::Canonical().Simplify(expr);
}

template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace Halide::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return Halide::Internal::simplify(a, true, rscope);
}


Expr Simplify(Expr a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}

Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
} // namespace ir
} // namespace tvm
8 changes: 8 additions & 0 deletions tests/python/unittest/test_pass_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def test_basic():
assert str(ret.value) == "(m - 1)"


def test_bound():
m = tvm.var('m')
vrange = tvm.convert({m: tvm.Range(tvm.const(0), tvm.const(10))})
ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m


def test_canonical():
x = tvm.var("x")
z = tvm.const(3)
Expand All @@ -37,6 +44,7 @@ def test_canonical():
assert(tvm.ir_pass.Equal(ret, 0))

if __name__ == "__main__":
test_bound()
test_basic()
test_simplify()
test_canonical()

0 comments on commit 9c0da90

Please sign in to comment.