From 8db682f00d3f8c0162e1013ed906f750252e906f Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Mon, 23 Mar 2020 15:30:47 -0400 Subject: [PATCH 01/15] New extension 'adstack' --- python/taichi/lang/__init__.py | 1 + taichi/inc/extensions.inc.h | 1 + taichi/llvm/llvm_context.cpp | 2 ++ taichi/program/extension.cpp | 6 +++--- tests/python/test_ad_if.py | 5 +++++ 5 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index ce4702c283b5c..681a2650252a8 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -41,6 +41,7 @@ class _Extension(object): def __init__(self): self.sparse = core.sparse self.data64 = core.data64 + self.adstack = core.adstack extension = _Extension() diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index a2aa2b64b023c..8083e19049103 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -1,3 +1,4 @@ // Lists of extension features PER_EXTENSION(sparse) PER_EXTENSION(data64) // Metal doesn't support 64-bit data buffers yet... +PER_EXTENSION(adstack) // For keeping the history of mutable local variables diff --git a/taichi/llvm/llvm_context.cpp b/taichi/llvm/llvm_context.cpp index f7a798fd7c629..d5f446798b54f 100644 --- a/taichi/llvm/llvm_context.cpp +++ b/taichi/llvm/llvm_context.cpp @@ -433,7 +433,9 @@ llvm::Value *TaichiLLVMContext::get_constant(T t) { std::is_same_v) { return llvm::ConstantInt::get(*ctx, llvm::APInt(32, (uint64)t, true)); } else if (std::is_same_v || + std::is_same_v || std::is_same_v) { + static_assert(sizeof(std::size_t) == sizeof(uint64)); return llvm::ConstantInt::get(*ctx, llvm::APInt(64, (uint64)t, true)); } else { TI_NOT_IMPLEMENTED diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index c0a948c6fc80a..9d8f6519c8ab9 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -7,9 +7,9 @@ TLANG_NAMESPACE_BEGIN bool is_supported(Arch arch, Extension ext) { static std::unordered_map> arch2ext = { - {Arch::x64, {Extension::sparse, Extension::data64}}, - {Arch::arm64, {Extension::sparse, Extension::data64}}, - {Arch::cuda, {Extension::sparse, Extension::data64}}, + {Arch::x64, {Extension::sparse, Extension::data64, Extension::adstack}}, + {Arch::arm64, {Extension::sparse, Extension::data64, Extension::adstack}}, + {Arch::cuda, {Extension::sparse, Extension::data64, Extension::adstack}}, {Arch::metal, {}}, {Arch::opengl, {}}, }; diff --git a/tests/python/test_ad_if.py b/tests/python/test_ad_if.py index 549e22f9277ae..83e9982ef1c39 100644 --- a/tests/python/test_ad_if.py +++ b/tests/python/test_ad_if.py @@ -1,6 +1,7 @@ import taichi as ti +@ti.require(ti.extension.adstack) @ti.all_archs def test_ad_if_simple(): x = ti.var(ti.f32, shape=()) @@ -22,6 +23,7 @@ def func(): assert x.grad[None] == 1 +@ti.require(ti.extension.adstack) @ti.all_archs def test_ad_if(): x = ti.var(ti.f32, shape=2) @@ -50,6 +52,7 @@ def func(i: ti.i32): assert x.grad[1] == 1 +@ti.require(ti.extension.adstack) @ti.all_archs def test_ad_if_mutable(): x = ti.var(ti.f32, shape=2) @@ -79,6 +82,7 @@ def func(i: ti.i32): assert x.grad[1] == 1 +@ti.require(ti.extension.adstack) @ti.all_archs def test_ad_if_prallel(): x = ti.var(ti.f32, shape=2) @@ -107,6 +111,7 @@ def func(): assert x.grad[1] == 1 +@ti.require(ti.extension.adstack) @ti.all_archs def test_ad_if_prallel_complex(): x = ti.var(ti.f32, shape=2) From 59354a4aa0b5e469ae5514aa802195af27693ed3 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Mon, 23 Mar 2020 16:48:18 -0400 Subject: [PATCH 02/15] [skip ci] enforce format --- taichi/inc/extensions.inc.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index 8083e19049103..6092d671252a0 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -1,4 +1,4 @@ // Lists of extension features PER_EXTENSION(sparse) -PER_EXTENSION(data64) // Metal doesn't support 64-bit data buffers yet... +PER_EXTENSION(data64) // Metal doesn't support 64-bit data buffers yet... PER_EXTENSION(adstack) // For keeping the history of mutable local variables From 063542ffe5f16805e9912def6b4e0c0e469faab6 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 23 Mar 2020 19:50:26 -0400 Subject: [PATCH 03/15] Simplify init in mgpcg_advanced.py (#645) --- examples/mgpcg_advanced.py | 23 +++++++++++------------ python/taichi/lang/transformer.py | 4 ++-- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/examples/mgpcg_advanced.py b/examples/mgpcg_advanced.py index 66b54301d9eff..636c7a3924d6f 100644 --- a/examples/mgpcg_advanced.py +++ b/examples/mgpcg_advanced.py @@ -48,18 +48,17 @@ def __init__(self): @ti.kernel def init(self): - for i, j, k in ti.ndrange((self.N_ext, self.N_tot - self.N_ext), - (self.N_ext, self.N_tot - self.N_ext), - (self.N_ext, self.N_tot - self.N_ext)): - xl = (i - self.N_ext) * 2.0 / self.N_tot - yl = (j - self.N_ext) * 2.0 / self.N_tot - zl = (k - self.N_ext) * 2.0 / self.N_tot - self.r[0][i, j, k] = ti.sin(2.0 * np.pi * xl) * ti.sin( - 2.0 * np.pi * yl) * ti.sin(2.0 * np.pi * zl) - self.z[0][i, j, k] = 0.0 - self.Ap[i, j, k] = 0.0 - self.p[i, j, k] = 0.0 - self.x[i, j, k] = 0.0 + for I in ti.grouped(ti.ndrange((self.N_ext, self.N_tot - self.N_ext), + (self.N_ext, self.N_tot - self.N_ext), + (self.N_ext, self.N_tot - self.N_ext))): + self.r[0][I] = 1.0 + for i in ti.static(range(self.dim)): + self.r[0][I] *= ti.sin(2.0 * np.pi * + (i - self.N_ext) * 2.0 / self.N_tot) + self.z[0][I] = 0.0 + self.Ap[I] = 0.0 + self.p[I] = 0.0 + self.x[I] = 0.0 @ti.func def neighbor_sum(self, x, I): diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 0b1ee041f689f..7a6c0bf8efdb8 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -344,10 +344,10 @@ def is_decorated(iter): ''' if is_ndrange_for == 1 else ''' if ti.static(1): __ndrange = 0 - {} = ti.Vector([0] * {}) + {} = ti.Vector([0] * len(__ndrange.dimensions)) for __ndrange_I in range(0): __I = __ndrange_I - '''.format(node.target.id, dim) + '''.format(node.target.id) t = ast.parse(template).body[0] t.body[0].value = node.iter if is_ndrange_for == 1 \ else node.iter.args[0] From b077a71752df582f4236c8cb7bb9834c428550c3 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 23 Mar 2020 19:57:39 -0400 Subject: [PATCH 04/15] put static/range/struct for into separate functions --- python/taichi/lang/transformer.py | 182 +++++++++++++++++------------- 1 file changed, 103 insertions(+), 79 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 7a6c0bf8efdb8..c439d8f648684 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -288,6 +288,106 @@ def is_decorated(iter): or iter.func.attr == 'grouped' or iter.func.attr == 'ndrange') + def get_targets(): + if isinstance(node.target, ast.Name): + return [node.target.id] + else: + assert isinstance(node.target, ast.Tuple) + return [name.id for name in node.target.elts] + + + def visit_static_for(): + # for i in ti.static(range(n)) + # for i, j in ti.static(ti.ndrange(n)) + # for I in ti.static(ti.grouped(ti.ndrange(n, m))) + self.generic_visit(node, ['body']) + t = self.parse_stmt('if 1: pass; del a') + t.body[0] = node + target = copy.deepcopy(node.target) + target.ctx = ast.Del() + if isinstance(target, ast.Tuple): + for tar in target.elts: + tar.ctx = ast.Del() + t.body[1].targets = [target] + return t + + + def visit_range_for(): + # for i in range(n) + self.generic_visit(node, ['body']) + loop_var = node.target.id + self.check_loop_var(loop_var) + template = ''' +if 1: + {} = ti.Expr(ti.core.make_id_expr('')) + ___begin = ti.Expr(0) + ___end = ti.Expr(0) + ___begin = ti.cast(___begin, ti.i32) + ___end = ti.cast(___end, ti.i32) + ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr) + ti.core.end_frontend_range_for() + '''.format(loop_var, loop_var) + t = ast.parse(template).body[0] + + assert len(node.iter.args) in [1, 2] + if len(node.iter.args) == 2: + bgn = node.iter.args[0] + end = node.iter.args[1] + else: + bgn = self.make_constant(value=0) + end = node.iter.args[0] + + t.body[1].value.args[0] = bgn + t.body[2].value.args[0] = end + t.body = t.body[:6] + node.body + t.body[6:] + t.body.append(self.parse_stmt('del {}'.format(loop_var))) + return ast.copy_location(t, node) + + + def visit_struct_for(is_grouped): + # for i, j in x + # for I in ti.grouped(x) + self.generic_visit(node, ['body']) + targets = get_targets() + + for loop_var in targets: + self.check_loop_var(loop_var) + + var_decl = ''.join( + ' {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(name) + for name in targets) # indent: 4 spaces + vars = ', '.join(targets) + if is_grouped: + template = ''' +if 1: + ___loop_var = 0 + {} = ti.make_var_vector(size=___loop_var.loop_range().dim()) + ___expr_group = ti.make_expr_group({}) + ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr) + ti.core.end_frontend_range_for() + '''.format(vars, vars) + t = ast.parse(template).body[0] + cut = 4 + t.body[0].value = node.iter + t.body = t.body[:cut] + node.body + t.body[cut:] + else: + template = ''' +if 1: +{} + ___loop_var = 0 + ___expr_group = ti.make_expr_group({}) + ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr) + ti.core.end_frontend_range_for() + '''.format(var_decl, vars) + t = ast.parse(template).body[0] + cut = len(targets) + 3 + t.body[cut - 3].value = node.iter + t.body = t.body[:cut] + node.body + t.body[cut:] + for loop_var in reversed(targets): + t.body.append(self.parse_stmt('del {}'.format(loop_var))) + return ast.copy_location(t, node) + + decorated = is_decorated(node.iter) double_decorated = decorated and len(node.iter.args) == 1 \ and is_decorated(node.iter.args[0]) @@ -331,8 +431,6 @@ def is_decorated(iter): ast.fix_missing_locations(node) is_nonstatic_ndrange = is_ndrange_for == 1 or (is_ndrange_for == 2 and is_grouped == 1) - if not is_nonstatic_ndrange: - self.generic_visit(node, ['body']) if is_nonstatic_ndrange: dim = len(node.iter.args) if is_ndrange_for == 1 else len( node.iter.args[0].args) @@ -386,87 +484,13 @@ def is_decorated(iter): node = ast.copy_location(t, node) return self.visit(node) # further translate as a range for elif is_static_for == 1: - t = self.parse_stmt('if 1: pass; del a') - t.body[0] = node - target = copy.deepcopy(node.target) - target.ctx = ast.Del() - if isinstance(target, ast.Tuple): - for tar in target.elts: - tar.ctx = ast.Del() - t.body[1].targets = [target] - return t + return visit_static_for() elif is_range_for == 1: - loop_var = node.target.id - self.check_loop_var(loop_var) - template = ''' -if 1: - {} = ti.Expr(ti.core.make_id_expr('')) - ___begin = ti.Expr(0) - ___end = ti.Expr(0) - ___begin = ti.cast(___begin, ti.i32) - ___end = ti.cast(___end, ti.i32) - ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr) - ti.core.end_frontend_range_for() - '''.format(loop_var, loop_var) - t = ast.parse(template).body[0] - - assert len(node.iter.args) in [1, 2] - if len(node.iter.args) == 2: - bgn = node.iter.args[0] - end = node.iter.args[1] - else: - bgn = self.make_constant(value=0) - end = node.iter.args[0] - - t.body[1].value.args[0] = bgn - t.body[2].value.args[0] = end - t.body = t.body[:6] + node.body + t.body[6:] - t.body.append(self.parse_stmt('del {}'.format(loop_var))) - return ast.copy_location(t, node) + return visit_range_for() else: # Struct for assert is_static_for == 0 assert is_ndrange_for == 0 - if isinstance(node.target, ast.Name): - elts = [node.target] - else: - elts = node.target.elts - - for loop_var in elts: - self.check_loop_var(loop_var.id) - - var_decl = ''.join( - ' {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(ind.id) - for ind in elts) # indent: 4 spaces - vars = ', '.join(ind.id for ind in elts) - if is_grouped: - template = ''' -if 1: - ___loop_var = 0 - {} = ti.make_var_vector(size=___loop_var.loop_range().dim()) - ___expr_group = ti.make_expr_group({}) - ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr) - ti.core.end_frontend_range_for() - '''.format(vars, vars) - t = ast.parse(template).body[0] - cut = 4 - t.body[0].value = node.iter - t.body = t.body[:cut] + node.body + t.body[cut:] - else: - template = ''' -if 1: -{} - ___loop_var = 0 - ___expr_group = ti.make_expr_group({}) - ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr) - ti.core.end_frontend_range_for() - '''.format(var_decl, vars) - t = ast.parse(template).body[0] - cut = len(elts) + 3 - t.body[cut - 3].value = node.iter - t.body = t.body[:cut] + node.body + t.body[cut:] - for loop_var in reversed(elts): - t.body.append(self.parse_stmt('del {}'.format(loop_var.id))) - return ast.copy_location(t, node) + return visit_struct_for(is_grouped) @staticmethod def parse_stmt(stmt): From 1110512e4631c529d88eed2fad002e3c16fc805d Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 23 Mar 2020 20:30:42 -0400 Subject: [PATCH 05/15] make decorators look clearer --- python/taichi/lang/transformer.py | 210 +++++++++++++++--------------- 1 file changed, 107 insertions(+), 103 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index c439d8f648684..b7df34432716d 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -279,14 +279,16 @@ def visit_For(self, node): raise TaichiSyntaxError( "'else' clause for 'for' not supported in Taichi kernels") - def is_decorated(iter): - return isinstance(iter, ast.Call) \ - and isinstance(iter.func, ast.Attribute) \ - and isinstance(iter.func.value, ast.Name) \ + def get_decorator(iter): + if not (isinstance(iter, ast.Call) + and isinstance(iter.func, ast.Attribute) + and isinstance(iter.func.value, ast.Name) and iter.func.value.id == 'ti' and ( iter.func.attr == 'static' or iter.func.attr == 'grouped' - or iter.func.attr == 'ndrange') + or iter.func.attr == 'ndrange')): + return '' + return iter.func.attr def get_targets(): if isinstance(node.target, ast.Name): @@ -344,6 +346,82 @@ def visit_range_for(): return ast.copy_location(t, node) + def visit_ndrange_for(node): + # for i, j in ti.ndrange(n) + template = ''' +if ti.static(1): + __ndrange = 0 + for __ndrange_I in range(0): + __I = __ndrange_I + ''' + t = ast.parse(template).body[0] + t.body[0].value = node.iter + t_loop = t.body[1] + t_loop.iter.args[0] = self.parse_expr( + '__ndrange.acc_dimensions[0]') + targets = get_targets() + targets_tmp = ['__' + name for name in targets] + loop_body = t_loop.body + for i in range(len(targets)): + if i + 1 < len(targets): + stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( + targets_tmp[i], i + 1) + else: + stmt = '{} = __I'.format(targets_tmp[i]) + loop_body.append(self.parse_stmt(stmt)) + stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( + targets[i], targets_tmp[i], i) + loop_body.append(self.parse_stmt(stmt)) + if i + 1 < len(targets): + stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( + targets_tmp[i], i + 1) + loop_body.append(self.parse_stmt(stmt)) + loop_body += node.body + + node = ast.copy_location(t, node) + return self.visit(node) # further translate as a range for + + + def visit_grouped_ndrange_for(node): + # for I in ti.grouped(ti.ndrange(n, m)) + dim = len(node.iter.args[0].args) + template = ''' +if ti.static(1): + __ndrange = 0 + {} = ti.Vector([0] * len(__ndrange.dimensions)) + for __ndrange_I in range(0): + __I = __ndrange_I + '''.format(node.target.id) + t = ast.parse(template).body[0] + t.body[0].value = node.iter.args[0] + t_loop = t.body[2] + t_loop.iter.args[0] = self.parse_expr( + '__ndrange.acc_dimensions[0]') + targets = ['{}[{}]'.format(node.target.id, i) for i in range(dim)] + targets_tmp = [ + '__{}_{}'.format(node.target.id, i) for i in range(dim) + ] + loop_body = t_loop.body + for i in range(len(targets)): + if i + 1 < len(targets): + stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( + targets_tmp[i], i + 1) + else: + stmt = '{} = __I'.format(targets_tmp[i]) + loop_body.append(self.parse_stmt(stmt)) + stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( + targets[i], targets_tmp[i], i) + loop_body.append(self.parse_stmt(stmt)) + if i + 1 < len(targets): + stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( + targets_tmp[i], i + 1) + loop_body.append(self.parse_stmt(stmt)) + loop_body += node.body + + node = ast.copy_location(t, node) + return self.visit(node) # further translate as a range for + + def visit_struct_for(is_grouped): # for i, j in x # for I in ti.grouped(x) @@ -388,109 +466,35 @@ def visit_struct_for(is_grouped): return ast.copy_location(t, node) - decorated = is_decorated(node.iter) - double_decorated = decorated and len(node.iter.args) == 1 \ - and is_decorated(node.iter.args[0]) - is_ndrange_for = 0 - is_static_for = 0 - is_grouped = 0 - - if decorated: - attr = node.iter.func - # outer decorator - if attr.attr == 'static': - is_static_for = 1 - elif attr.attr == 'grouped': - is_grouped = 1 - elif attr.attr == 'ndrange': - is_ndrange_for = 1 - else: - raise Exception('Not supported') - if double_decorated: - attr = node.iter.args[0].func - # inner decorator - if attr.attr == 'static': - if is_static_for == 1: - raise TaichiSyntaxError("'ti.static' cannot be nested") - is_static_for = 2 - elif attr.attr == 'grouped': - if is_grouped == 1: - raise TaichiSyntaxError( - "'ti.grouped' cannot be nested") - is_grouped = 2 - elif attr.attr == 'ndrange': - if is_ndrange_for == 1: - raise TaichiSyntaxError( - "'ti.ndrange' cannot be nested") - is_ndrange_for = 2 - else: - raise Exception('Not supported') - - is_range_for = isinstance(node.iter, ast.Call) and isinstance( - node.iter.func, ast.Name) and node.iter.func.id == 'range' + decorator = get_decorator(node.iter) + double_decorator = '' + if decorator != '' and len(node.iter.args) == 1: + double_decorator = get_decorator(node.iter.args[0]) ast.fix_missing_locations(node) - is_nonstatic_ndrange = is_ndrange_for == 1 or (is_ndrange_for == 2 - and is_grouped == 1) - if is_nonstatic_ndrange: - dim = len(node.iter.args) if is_ndrange_for == 1 else len( - node.iter.args[0].args) - template = ''' -if ti.static(1): - __ndrange = 0 - for __ndrange_I in range(0): - __I = __ndrange_I - ''' if is_ndrange_for == 1 else ''' -if ti.static(1): - __ndrange = 0 - {} = ti.Vector([0] * len(__ndrange.dimensions)) - for __ndrange_I in range(0): - __I = __ndrange_I - '''.format(node.target.id) - t = ast.parse(template).body[0] - t.body[0].value = node.iter if is_ndrange_for == 1 \ - else node.iter.args[0] - t_loop = t.body[1] if is_ndrange_for == 1 else t.body[2] - t_loop.iter.args[0] = self.parse_expr( - '__ndrange.acc_dimensions[0]') - targets = node.target - if is_ndrange_for == 1: - if isinstance(targets, ast.Tuple): - targets = [name.id for name in targets.elts] - else: - targets = [targets.id] - targets_tmp = ['__' + name for name in targets] - else: - targets = ['{}[{}]'.format(targets.id, i) for i in range(dim)] - targets_tmp = [ - '__{}_{}'.format(node.target.id, i) for i in range(dim) - ] - loop_body = t_loop.body - for i in range(len(targets)): - if i + 1 < len(targets): - stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) - else: - stmt = '{} = __I'.format(targets_tmp[i]) - loop_body.append(self.parse_stmt(stmt)) - stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( - targets[i], targets_tmp[i], i) - loop_body.append(self.parse_stmt(stmt)) - if i + 1 < len(targets): - stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) - loop_body.append(self.parse_stmt(stmt)) - loop_body += node.body - node = ast.copy_location(t, node) - return self.visit(node) # further translate as a range for - elif is_static_for == 1: + if decorator == 'static': + if double_decorator == 'static': + raise TaichiSyntaxError("'ti.static' cannot be nested") return visit_static_for() - elif is_range_for == 1: + elif decorator == 'ndrange': + if double_decorator != '': + raise TaichiSyntaxError( + "No decorator is allowed inside 'ti.ndrange") + return visit_ndrange_for(node) + elif decorator == 'grouped': + if double_decorator == 'static': + pass + elif double_decorator == 'ndrange': + return visit_grouped_ndrange_for(node) + elif double_decorator == 'grouped': + raise TaichiSyntaxError("'ti.grouped' cannot be nested") + else: + return visit_struct_for(is_grouped=True) + elif isinstance(node.iter, ast.Call) and isinstance( + node.iter.func, ast.Name) and node.iter.func.id == 'range': return visit_range_for() else: # Struct for - assert is_static_for == 0 - assert is_ndrange_for == 0 - return visit_struct_for(is_grouped) + return visit_struct_for(is_grouped=False) @staticmethod def parse_stmt(stmt): From 75c90f63f89b826763bdf48080f80f0403154bdd Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 23 Mar 2020 20:42:20 -0400 Subject: [PATCH 06/15] [skip ci] raise syntax error for more cases --- python/taichi/lang/transformer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index b7df34432716d..9e6641323dbdd 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -475,6 +475,11 @@ def visit_struct_for(is_grouped): if decorator == 'static': if double_decorator == 'static': raise TaichiSyntaxError("'ti.static' cannot be nested") + if double_decorator == 'grouped' and ( + len(node.iter.args[0].args) != 1 + or get_decorator(node.iter.args[0].args[0]) == ''): + raise TaichiSyntaxError( + "Static grouped struct for loop is not allowed. Please use 'ti.static(ti.grouped(ti.ndrange(...)))' instead.") return visit_static_for() elif decorator == 'ndrange': if double_decorator != '': @@ -483,7 +488,8 @@ def visit_struct_for(is_grouped): return visit_ndrange_for(node) elif decorator == 'grouped': if double_decorator == 'static': - pass + raise TaichiSyntaxError( + "'ti.static' is not allowed inside 'ti.grouped'") elif double_decorator == 'ndrange': return visit_grouped_ndrange_for(node) elif double_decorator == 'grouped': From d8a3dae9048ab9cc41aa15527d5939746b169be8 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 23 Mar 2020 21:36:46 -0400 Subject: [PATCH 07/15] [skip ci] dynamic grouped ndrange? --- python/taichi/lang/transformer.py | 48 ++++++++++++++++--------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 9e6641323dbdd..d4dad3f32cb89 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -383,39 +383,41 @@ def visit_ndrange_for(node): def visit_grouped_ndrange_for(node): + from astpretty import pprint # for I in ti.grouped(ti.ndrange(n, m)) - dim = len(node.iter.args[0].args) + target = node.target.id template = ''' if ti.static(1): __ndrange = 0 {} = ti.Vector([0] * len(__ndrange.dimensions)) - for __ndrange_I in range(0): + for __ndrange_I in range(__ndrange.acc_dimensions[0]): __I = __ndrange_I - '''.format(node.target.id) + for __grouped_I in range(len(__ndrange.dimensions)): + __grouped_I_tmp = 0 + if __grouped_I + 1 < len(__ndrange.dimensions): + __grouped_I_tmp = 233 + else: + __grouped_I_tmp = __I + __tmp = 233 + if __grouped_I + 1 < len(__ndrange.dimensions): + __I = 233 + '''.format(target, target) t = ast.parse(template).body[0] + print('ttttttttttttttttttttttttttttttttttttttttttttt') + pprint(t) t.body[0].value = node.iter.args[0] t_loop = t.body[2] - t_loop.iter.args[0] = self.parse_expr( - '__ndrange.acc_dimensions[0]') - targets = ['{}[{}]'.format(node.target.id, i) for i in range(dim)] - targets_tmp = [ - '__{}_{}'.format(node.target.id, i) for i in range(dim) - ] loop_body = t_loop.body - for i in range(len(targets)): - if i + 1 < len(targets): - stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) - else: - stmt = '{} = __I'.format(targets_tmp[i]) - loop_body.append(self.parse_stmt(stmt)) - stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( - targets[i], targets_tmp[i], i) - loop_body.append(self.parse_stmt(stmt)) - if i + 1 < len(targets): - stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) - loop_body.append(self.parse_stmt(stmt)) + inner_loop_body = loop_body[1].body + inner_loop_body[1].body[0].value = self.parse_expr( + '__I // __ndrange.acc_dimensions[__grouped_I + 1]') + inner_loop_body[2].targets[0] = self.parse_expr( + '{}[__grouped_I]'.format(target)) + inner_loop_body[2].value = self.parse_expr( + '__grouped_I_tmp + __ndrange.bounds[__grouped_I][0]') + inner_loop_body[3].body[0].value = self.parse_expr( + '__I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]' + ) loop_body += node.body node = ast.copy_location(t, node) From 2a0730059ee511d0420a39ba7f903da8275cf8e2 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 24 Mar 2020 14:40:05 -0400 Subject: [PATCH 08/15] [skip ci] debug --- python/taichi/lang/transformer.py | 64 ++++++++++++++++++------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index d4dad3f32cb89..690f30edf48ca 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -383,45 +383,57 @@ def visit_ndrange_for(node): def visit_grouped_ndrange_for(node): + self.generic_visit(node, ['body']) from astpretty import pprint # for I in ti.grouped(ti.ndrange(n, m)) target = node.target.id template = ''' if ti.static(1): __ndrange = 0 - {} = ti.Vector([0] * len(__ndrange.dimensions)) - for __ndrange_I in range(__ndrange.acc_dimensions[0]): - __I = __ndrange_I - for __grouped_I in range(len(__ndrange.dimensions)): - __grouped_I_tmp = 0 - if __grouped_I + 1 < len(__ndrange.dimensions): - __grouped_I_tmp = 233 - else: - __grouped_I_tmp = __I - __tmp = 233 - if __grouped_I + 1 < len(__ndrange.dimensions): - __I = 233 + {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions))) + ___begin = ti.Expr(0) + ___end = 0 + ___begin = ti.cast(___begin, ti.i32) + ___end = ti.cast(___end, ti.i32) + __ndrange_I = ti.Expr(0) + ti.core.begin_frontend_range_for(__ndrange_I.ptr, ___begin.ptr, ___end.ptr) + __I = __ndrange_I + for __grouped_I in range(len(__ndrange.dimensions)): + __grouped_I_tmp = 0 + if __grouped_I + 1 < len(__ndrange.dimensions): + __grouped_I_tmp = __I // __ndrange.acc_dimensions[__grouped_I + 1] + else: + __grouped_I_tmp = __I + ti.subscript({}, __grouped_I).assign(__grouped_I_tmp + __ndrange.bounds[__grouped_I][0]) + if __grouped_I + 1 < len(__ndrange.dimensions): + __I = __I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1] + ti.core.end_frontend_range_for() '''.format(target, target) t = ast.parse(template).body[0] print('ttttttttttttttttttttttttttttttttttttttttttttt') pprint(t) t.body[0].value = node.iter.args[0] - t_loop = t.body[2] - loop_body = t_loop.body - inner_loop_body = loop_body[1].body - inner_loop_body[1].body[0].value = self.parse_expr( - '__I // __ndrange.acc_dimensions[__grouped_I + 1]') - inner_loop_body[2].targets[0] = self.parse_expr( - '{}[__grouped_I]'.format(target)) - inner_loop_body[2].value = self.parse_expr( - '__grouped_I_tmp + __ndrange.bounds[__grouped_I][0]') - inner_loop_body[3].body[0].value = self.parse_expr( - '__I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]' - ) - loop_body += node.body + t.body[3].value = self.parse_expr('__ndrange.acc_dimensions[0]') + + # t_loop = t.body[2] + # loop_body = t_loop.body + # inner_loop_body = loop_body[1].body + # inner_loop_body[1].body[0].value = self.parse_expr( + # '__I // __ndrange.acc_dimensions[__grouped_I + 1]') + # inner_loop_body[2].targets[0] = self.parse_expr( + # '{}[__grouped_I]'.format(target)) + # inner_loop_body[2].value = self.parse_expr( + # '__grouped_I_tmp + __ndrange.bounds[__grouped_I][0]') + # inner_loop_body[3].body[0].value = self.parse_expr( + # '__I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]' + # ) + # loop_body += node.body + + cut = len(t.body) - 1 + t.body = t.body[:cut] + node.body + t.body[cut:] node = ast.copy_location(t, node) - return self.visit(node) # further translate as a range for + return node def visit_struct_for(is_grouped): From 774ef9a1b50eb953d8650be08ea58ac4c5164946 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 24 Mar 2020 14:49:02 -0400 Subject: [PATCH 09/15] fix --- python/taichi/lang/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 690f30edf48ca..ce5fb09a4ad56 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -395,7 +395,7 @@ def visit_grouped_ndrange_for(node): ___end = 0 ___begin = ti.cast(___begin, ti.i32) ___end = ti.cast(___end, ti.i32) - __ndrange_I = ti.Expr(0) + __ndrange_I = ti.Expr(ti.core.make_id_expr('')) ti.core.begin_frontend_range_for(__ndrange_I.ptr, ___begin.ptr, ___end.ptr) __I = __ndrange_I for __grouped_I in range(len(__ndrange.dimensions)): From 72c4e4bc82d08c9e1add5a5b93534eaac13c3c12 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 24 Mar 2020 14:57:25 -0400 Subject: [PATCH 10/15] add more tests --- python/taichi/lang/transformer.py | 27 ++------------- tests/python/test_grouped.py | 54 +++++++++++++++++++++++++++++- tests/python/test_syntax_errors.py | 14 ++++++++ 3 files changed, 70 insertions(+), 25 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index ce5fb09a4ad56..2336b5b944fd6 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -383,16 +383,15 @@ def visit_ndrange_for(node): def visit_grouped_ndrange_for(node): - self.generic_visit(node, ['body']) - from astpretty import pprint # for I in ti.grouped(ti.ndrange(n, m)) + self.generic_visit(node, ['body']) target = node.target.id template = ''' if ti.static(1): __ndrange = 0 {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions))) ___begin = ti.Expr(0) - ___end = 0 + ___end = __ndrange.acc_dimensions[0] ___begin = ti.cast(___begin, ti.i32) ___end = ti.cast(___end, ti.i32) __ndrange_I = ti.Expr(ti.core.make_id_expr('')) @@ -410,30 +409,10 @@ def visit_grouped_ndrange_for(node): ti.core.end_frontend_range_for() '''.format(target, target) t = ast.parse(template).body[0] - print('ttttttttttttttttttttttttttttttttttttttttttttt') - pprint(t) t.body[0].value = node.iter.args[0] - t.body[3].value = self.parse_expr('__ndrange.acc_dimensions[0]') - - # t_loop = t.body[2] - # loop_body = t_loop.body - # inner_loop_body = loop_body[1].body - # inner_loop_body[1].body[0].value = self.parse_expr( - # '__I // __ndrange.acc_dimensions[__grouped_I + 1]') - # inner_loop_body[2].targets[0] = self.parse_expr( - # '{}[__grouped_I]'.format(target)) - # inner_loop_body[2].value = self.parse_expr( - # '__grouped_I_tmp + __ndrange.bounds[__grouped_I][0]') - # inner_loop_body[3].body[0].value = self.parse_expr( - # '__I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]' - # ) - # loop_body += node.body - cut = len(t.body) - 1 t.body = t.body[:cut] + node.body + t.body[cut:] - - node = ast.copy_location(t, node) - return node + return ast.copy_location(t, node) def visit_struct_for(is_grouped): diff --git a/tests/python/test_grouped.py b/tests/python/test_grouped.py index 26339ff69888f..9a1c92d0c0b24 100644 --- a/tests/python/test_grouped.py +++ b/tests/python/test_grouped.py @@ -81,6 +81,58 @@ def test(): j * 2 if x0 <= i < y0 and x1 <= j < y1 else 0) +@ti.all_archs +def test_static_grouped_ndrange(): + val = ti.var(ti.i32) + + n = 4 + m = 8 + + ti.root.dense(ti.ij, (n, m)).place(val) + + x0 = 2 + y0 = 3 + x1 = 1 + y1 = 6 + + @ti.kernel + def test(): + for I in ti.static(ti.grouped(ti.ndrange((x0, y0), (x1, y1)))): + val[I] = I[0] + I[1] * 2 + + test() + + for i in range(n): + for j in range(m): + assert val[i, j] == (i + + j * 2 if x0 <= i < y0 and x1 <= j < y1 else 0) + + +@ti.all_archs +def test_grouped_ndrange_starred(): + val = ti.var(ti.i32) + + n = 4 + m = 8 + p = 16 + dim = 3 + + ti.root.dense(ti.ijk, (n, m, p)).place(val) + + @ti.kernel + def test(): + for I in ti.grouped(ti.ndrange(*(((0, n),) * dim))): + val[I] = I[0] + I[1] * 2 + I[2] * 3 + + test() + + for i in range(n): + for j in range(m): + for k in range(p): + assert val[i, j, k] == (i + j * 2 + k * 3 + if j < n and k < n else 0) + + @ti.all_archs def test_grouped_ndrange_0d(): val = ti.var(ti.i32, shape=()) @@ -96,7 +148,7 @@ def test(): @ti.all_archs -def test_grouped_ndrange_0d(): +def test_static_grouped_ndrange_0d(): val = ti.var(ti.i32, shape=()) @ti.kernel diff --git a/tests/python/test_syntax_errors.py b/tests/python/test_syntax_errors.py index 92fa0ee1bb5ad..66bae69375a2e 100644 --- a/tests/python/test_syntax_errors.py +++ b/tests/python/test_syntax_errors.py @@ -210,3 +210,17 @@ def func(): pass func() + + +@ti.must_throw(ti.TaichiSyntaxError) +def test_static_grouped_struct_for(): + val = ti.var(ti.i32) + + ti.root.dense(ti.ij, (1, 1)).place(val) + + @ti.kernel + def test(): + for I in ti.static(ti.grouped(val)): + pass + + test() From 02309ad85a331a791a89c2d6dd1cc23ba3f9dff8 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 24 Mar 2020 15:07:14 -0400 Subject: [PATCH 11/15] [skip ci] enforce code format --- CMakeLists.txt | 5 ++--- examples/mgpcg_advanced.py | 11 ++++++----- python/taichi/lang/transformer.py | 20 +++++++------------- tests/python/test_grouped.py | 6 +++--- 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 095a7919af566..e8b99e29e35b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,12 +11,12 @@ SET(TI_VERSION_MINOR 5) SET(TI_VERSION_PATCH 8) execute_process( - WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} COMMAND git rev-parse --short HEAD RESULT_VARIABLE SHORT_HASH_RESULT OUTPUT_VARIABLE TI_COMMIT_SHORT_HASH) execute_process( - WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} COMMAND git rev-parse HEAD RESULT_VARIABLE SHORT_HASH_RESULT OUTPUT_VARIABLE TI_COMMIT_HASH) @@ -71,4 +71,3 @@ FILE(WRITE ${CMAKE_CURRENT_LIST_DIR}/taichi/common/version.h "#define TI_CUDAVERSION \"${CUDA_VERSION}\"\n" "#define TI_CUDAROOT_DIR \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" ) - diff --git a/examples/mgpcg_advanced.py b/examples/mgpcg_advanced.py index 636c7a3924d6f..a5296f12de09d 100644 --- a/examples/mgpcg_advanced.py +++ b/examples/mgpcg_advanced.py @@ -48,13 +48,14 @@ def __init__(self): @ti.kernel def init(self): - for I in ti.grouped(ti.ndrange((self.N_ext, self.N_tot - self.N_ext), - (self.N_ext, self.N_tot - self.N_ext), - (self.N_ext, self.N_tot - self.N_ext))): + for I in ti.grouped( + ti.ndrange((self.N_ext, self.N_tot - self.N_ext), + (self.N_ext, self.N_tot - self.N_ext), + (self.N_ext, self.N_tot - self.N_ext))): self.r[0][I] = 1.0 for i in ti.static(range(self.dim)): - self.r[0][I] *= ti.sin(2.0 * np.pi * - (i - self.N_ext) * 2.0 / self.N_tot) + self.r[0][I] *= ti.sin(2.0 * np.pi * (i - self.N_ext) * 2.0 / + self.N_tot) self.z[0][I] = 0.0 self.Ap[I] = 0.0 self.p[I] = 0.0 diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 2336b5b944fd6..ccf4bddddffff 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -281,12 +281,11 @@ def visit_For(self, node): def get_decorator(iter): if not (isinstance(iter, ast.Call) - and isinstance(iter.func, ast.Attribute) - and isinstance(iter.func.value, ast.Name) - and iter.func.value.id == 'ti' and ( - iter.func.attr == 'static' - or iter.func.attr == 'grouped' - or iter.func.attr == 'ndrange')): + and isinstance(iter.func, ast.Attribute) + and isinstance(iter.func.value, ast.Name) + and iter.func.value.id == 'ti' and + (iter.func.attr == 'static' or iter.func.attr == 'grouped' + or iter.func.attr == 'ndrange')): return '' return iter.func.attr @@ -297,7 +296,6 @@ def get_targets(): assert isinstance(node.target, ast.Tuple) return [name.id for name in node.target.elts] - def visit_static_for(): # for i in ti.static(range(n)) # for i, j in ti.static(ti.ndrange(n)) @@ -313,7 +311,6 @@ def visit_static_for(): t.body[1].targets = [target] return t - def visit_range_for(): # for i in range(n) self.generic_visit(node, ['body']) @@ -345,7 +342,6 @@ def visit_range_for(): t.body.append(self.parse_stmt('del {}'.format(loop_var))) return ast.copy_location(t, node) - def visit_ndrange_for(node): # for i, j in ti.ndrange(n) template = ''' @@ -381,7 +377,6 @@ def visit_ndrange_for(node): node = ast.copy_location(t, node) return self.visit(node) # further translate as a range for - def visit_grouped_ndrange_for(node): # for I in ti.grouped(ti.ndrange(n, m)) self.generic_visit(node, ['body']) @@ -414,7 +409,6 @@ def visit_grouped_ndrange_for(node): t.body = t.body[:cut] + node.body + t.body[cut:] return ast.copy_location(t, node) - def visit_struct_for(is_grouped): # for i, j in x # for I in ti.grouped(x) @@ -458,7 +452,6 @@ def visit_struct_for(is_grouped): t.body.append(self.parse_stmt('del {}'.format(loop_var))) return ast.copy_location(t, node) - decorator = get_decorator(node.iter) double_decorator = '' if decorator != '' and len(node.iter.args) == 1: @@ -472,7 +465,8 @@ def visit_struct_for(is_grouped): len(node.iter.args[0].args) != 1 or get_decorator(node.iter.args[0].args[0]) == ''): raise TaichiSyntaxError( - "Static grouped struct for loop is not allowed. Please use 'ti.static(ti.grouped(ti.ndrange(...)))' instead.") + "Static grouped struct for loop is not allowed. Please use 'ti.static(ti.grouped(ti.ndrange(...)))' instead." + ) return visit_static_for() elif decorator == 'ndrange': if double_decorator != '': diff --git a/tests/python/test_grouped.py b/tests/python/test_grouped.py index 9a1c92d0c0b24..c18d1cde04e7c 100644 --- a/tests/python/test_grouped.py +++ b/tests/python/test_grouped.py @@ -121,7 +121,7 @@ def test_grouped_ndrange_starred(): @ti.kernel def test(): - for I in ti.grouped(ti.ndrange(*(((0, n),) * dim))): + for I in ti.grouped(ti.ndrange(*(((0, n), ) * dim))): val[I] = I[0] + I[1] * 2 + I[2] * 3 test() @@ -129,8 +129,8 @@ def test(): for i in range(n): for j in range(m): for k in range(p): - assert val[i, j, k] == (i + j * 2 + k * 3 - if j < n and k < n else 0) + assert val[i, j, + k] == (i + j * 2 + k * 3 if j < n and k < n else 0) @ti.all_archs From 569b7ed59569df9ce71d297cce2291ca1cd15c8c Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 24 Mar 2020 15:18:24 -0400 Subject: [PATCH 12/15] further simplify mgpcg_advanced.py and make it able to run with self.dim=2 --- examples/mgpcg_advanced.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/examples/mgpcg_advanced.py b/examples/mgpcg_advanced.py index 636c7a3924d6f..efbec5bc9d52e 100644 --- a/examples/mgpcg_advanced.py +++ b/examples/mgpcg_advanced.py @@ -35,22 +35,20 @@ def __init__(self): self.pixels = ti.var(dt=real, shape=(self.N_gui, self.N_gui)) # image buffer - self.grid = ti.root.pointer(ti.ijk, [self.N_tot // 4]).dense( - ti.ijk, 4).place(self.x, self.p, self.Ap) + ijk = ti.ijk if self.dim == 3 else ti.ij + self.grid = ti.root.pointer(ijk, [self.N_tot // 4]).dense( + ijk, 4).place(self.x, self.p, self.Ap) for l in range(self.n_mg_levels): - self.grid = ti.root.pointer(ti.ijk, - [self.N_tot // (4 * 2**l)]).dense( - ti.ijk, - 4).place(self.r[l], self.z[l]) + self.grid = ti.root.pointer(ijk, [self.N_tot // (4 * 2**l)]).dense( + ijk, 4).place(self.r[l], self.z[l]) ti.root.place(self.alpha, self.beta, self.sum) @ti.kernel def init(self): - for I in ti.grouped(ti.ndrange((self.N_ext, self.N_tot - self.N_ext), - (self.N_ext, self.N_tot - self.N_ext), - (self.N_ext, self.N_tot - self.N_ext))): + for I in ti.grouped(ti.ndrange( + *((self.N_ext, self.N_tot - self.N_ext), ) * self.dim)): self.r[0][I] = 1.0 for i in ti.static(range(self.dim)): self.r[0][I] *= ti.sin(2.0 * np.pi * @@ -98,7 +96,7 @@ def update_p(self): @ti.kernel def restrict(self, l: ti.template()): for I in ti.grouped(self.r[l]): - res = self.r[l][I] - (6.0 * self.z[l][I] - + res = self.r[l][I] - (2.0 * self.dim * self.z[l][I] - self.neighbor_sum(self.z[l], I)) self.r[l + 1][I // 2] += res * 0.5 @@ -112,8 +110,8 @@ def smooth(self, l: ti.template(), phase: ti.template()): # phase = red/black Gauss-Seidel phase for I in ti.grouped(self.r[l]): if (I.sum()) & 1 == phase: - self.z[l][I] = (self.r[l][I] + - self.neighbor_sum(self.z[l], I)) / 6.0 + self.z[l][I] = (self.r[l][I] + self.neighbor_sum( + self.z[l], I)) / (2.0 * self.dim) def apply_preconditioner(self): self.z[0].fill(0) @@ -137,11 +135,12 @@ def apply_preconditioner(self): @ti.kernel def paint(self): - kk = self.N_tot * 3 // 8 - for i, j in self.pixels: - ii = int(i * self.N / self.N_gui) + self.N_ext - jj = int(j * self.N / self.N_gui) + self.N_ext - self.pixels[i, j] = self.x[ii, jj, kk] / self.N_tot + if ti.static(self.dim == 3): + kk = self.N_tot * 3 // 8 + for i, j in self.pixels: + ii = int(i * self.N / self.N_gui) + self.N_ext + jj = int(j * self.N / self.N_gui) + self.N_ext + self.pixels[i, j] = self.x[ii, jj, kk] / self.N_tot def run(self): gui = ti.GUI("Multigrid Preconditioned Conjugate Gradients", From 306152bf0f79d8121b39b0bafe5a7a6602eaf925 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 24 Mar 2020 15:25:41 -0400 Subject: [PATCH 13/15] [skip ci] enforce code format --- examples/mgpcg_advanced.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/mgpcg_advanced.py b/examples/mgpcg_advanced.py index a413f80d076d4..8019b78c18ac2 100644 --- a/examples/mgpcg_advanced.py +++ b/examples/mgpcg_advanced.py @@ -36,8 +36,9 @@ def __init__(self): shape=(self.N_gui, self.N_gui)) # image buffer ijk = ti.ijk if self.dim == 3 else ti.ij - self.grid = ti.root.pointer(ijk, [self.N_tot // 4]).dense( - ijk, 4).place(self.x, self.p, self.Ap) + self.grid = ti.root.pointer(ijk, + [self.N_tot // 4]).dense(ijk, 4).place( + self.x, self.p, self.Ap) for l in range(self.n_mg_levels): self.grid = ti.root.pointer(ijk, [self.N_tot // (4 * 2**l)]).dense( @@ -47,8 +48,9 @@ def __init__(self): @ti.kernel def init(self): - for I in ti.grouped(ti.ndrange( - *((self.N_ext, self.N_tot - self.N_ext), ) * self.dim)): + for I in ti.grouped( + ti.ndrange(*( + (self.N_ext, self.N_tot - self.N_ext), ) * self.dim)): self.r[0][I] = 1.0 for i in ti.static(range(self.dim)): self.r[0][I] *= ti.sin(2.0 * np.pi * (i - self.N_ext) * 2.0 / @@ -111,7 +113,7 @@ def smooth(self, l: ti.template(), phase: ti.template()): for I in ti.grouped(self.r[l]): if (I.sum()) & 1 == phase: self.z[l][I] = (self.r[l][I] + self.neighbor_sum( - self.z[l], I)) / (2.0 * self.dim) + self.z[l], I)) / (2.0 * self.dim) def apply_preconditioner(self): self.z[0].fill(0) From 224663ba3bb5057798af69d565942bb6a083753c Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 24 Mar 2020 17:11:15 -0400 Subject: [PATCH 14/15] [skip ci] rename ijk to indices --- examples/mgpcg_advanced.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/mgpcg_advanced.py b/examples/mgpcg_advanced.py index 8019b78c18ac2..bf714857a0b07 100644 --- a/examples/mgpcg_advanced.py +++ b/examples/mgpcg_advanced.py @@ -35,14 +35,16 @@ def __init__(self): self.pixels = ti.var(dt=real, shape=(self.N_gui, self.N_gui)) # image buffer - ijk = ti.ijk if self.dim == 3 else ti.ij - self.grid = ti.root.pointer(ijk, - [self.N_tot // 4]).dense(ijk, 4).place( + indices = ti.ijk if self.dim == 3 else ti.ij + self.grid = ti.root.pointer(indices, + [self.N_tot // 4]).dense(indices, 4).place( self.x, self.p, self.Ap) for l in range(self.n_mg_levels): - self.grid = ti.root.pointer(ijk, [self.N_tot // (4 * 2**l)]).dense( - ijk, 4).place(self.r[l], self.z[l]) + self.grid = ti.root.pointer(indices, + [self.N_tot // (4 * 2**l)]).dense( + indices, 4).place( + self.r[l], self.z[l]) ti.root.place(self.alpha, self.beta, self.sum) From 25f3c7d99a4d828b4960507f3c7eb000068de581 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 24 Mar 2020 17:19:02 -0400 Subject: [PATCH 15/15] move components of visit_For outside --- python/taichi/lang/transformer.py | 277 +++++++++++++++--------------- 1 file changed, 139 insertions(+), 138 deletions(-) diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index ccf4bddddffff..3a4b6831a3dea 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -274,49 +274,46 @@ def check_loop_var(self, loop_var): "Variable '{}' is already declared in the outer scope and cannot be used as loop variable" .format(loop_var)) - def visit_For(self, node): - if node.orelse: - raise TaichiSyntaxError( - "'else' clause for 'for' not supported in Taichi kernels") + @staticmethod + def get_decorator(iter): + if not (isinstance(iter, ast.Call) + and isinstance(iter.func, ast.Attribute) + and isinstance(iter.func.value, ast.Name) + and iter.func.value.id == 'ti' and + (iter.func.attr == 'static' or iter.func.attr == 'grouped' + or iter.func.attr == 'ndrange')): + return '' + return iter.func.attr - def get_decorator(iter): - if not (isinstance(iter, ast.Call) - and isinstance(iter.func, ast.Attribute) - and isinstance(iter.func.value, ast.Name) - and iter.func.value.id == 'ti' and - (iter.func.attr == 'static' or iter.func.attr == 'grouped' - or iter.func.attr == 'ndrange')): - return '' - return iter.func.attr - - def get_targets(): - if isinstance(node.target, ast.Name): - return [node.target.id] - else: - assert isinstance(node.target, ast.Tuple) - return [name.id for name in node.target.elts] - - def visit_static_for(): - # for i in ti.static(range(n)) - # for i, j in ti.static(ti.ndrange(n)) - # for I in ti.static(ti.grouped(ti.ndrange(n, m))) - self.generic_visit(node, ['body']) - t = self.parse_stmt('if 1: pass; del a') - t.body[0] = node - target = copy.deepcopy(node.target) - target.ctx = ast.Del() - if isinstance(target, ast.Tuple): - for tar in target.elts: - tar.ctx = ast.Del() - t.body[1].targets = [target] - return t - - def visit_range_for(): - # for i in range(n) - self.generic_visit(node, ['body']) - loop_var = node.target.id - self.check_loop_var(loop_var) - template = ''' + @staticmethod + def get_targets(node): + if isinstance(node.target, ast.Name): + return [node.target.id] + else: + assert isinstance(node.target, ast.Tuple) + return [name.id for name in node.target.elts] + + def visit_static_for(self, node): + # for i in ti.static(range(n)) + # for i, j in ti.static(ti.ndrange(n)) + # for I in ti.static(ti.grouped(ti.ndrange(n, m))) + self.generic_visit(node, ['body']) + t = self.parse_stmt('if 1: pass; del a') + t.body[0] = node + target = copy.deepcopy(node.target) + target.ctx = ast.Del() + if isinstance(target, ast.Tuple): + for tar in target.elts: + tar.ctx = ast.Del() + t.body[1].targets = [target] + return t + + def visit_range_for(self, node): + # for i in range(n) + self.generic_visit(node, ['body']) + loop_var = node.target.id + self.check_loop_var(loop_var) + template = ''' if 1: {} = ti.Expr(ti.core.make_id_expr('')) ___begin = ti.Expr(0) @@ -325,63 +322,62 @@ def visit_range_for(): ___end = ti.cast(___end, ti.i32) ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr) ti.core.end_frontend_range_for() - '''.format(loop_var, loop_var) - t = ast.parse(template).body[0] + '''.format(loop_var, loop_var) + t = ast.parse(template).body[0] - assert len(node.iter.args) in [1, 2] - if len(node.iter.args) == 2: - bgn = node.iter.args[0] - end = node.iter.args[1] - else: - bgn = self.make_constant(value=0) - end = node.iter.args[0] + assert len(node.iter.args) in [1, 2] + if len(node.iter.args) == 2: + bgn = node.iter.args[0] + end = node.iter.args[1] + else: + bgn = self.make_constant(value=0) + end = node.iter.args[0] - t.body[1].value.args[0] = bgn - t.body[2].value.args[0] = end - t.body = t.body[:6] + node.body + t.body[6:] - t.body.append(self.parse_stmt('del {}'.format(loop_var))) - return ast.copy_location(t, node) + t.body[1].value.args[0] = bgn + t.body[2].value.args[0] = end + t.body = t.body[:6] + node.body + t.body[6:] + t.body.append(self.parse_stmt('del {}'.format(loop_var))) + return ast.copy_location(t, node) - def visit_ndrange_for(node): - # for i, j in ti.ndrange(n) - template = ''' + def visit_ndrange_for(self, node): + # for i, j in ti.ndrange(n) + template = ''' if ti.static(1): __ndrange = 0 for __ndrange_I in range(0): __I = __ndrange_I - ''' - t = ast.parse(template).body[0] - t.body[0].value = node.iter - t_loop = t.body[1] - t_loop.iter.args[0] = self.parse_expr( - '__ndrange.acc_dimensions[0]') - targets = get_targets() - targets_tmp = ['__' + name for name in targets] - loop_body = t_loop.body - for i in range(len(targets)): - if i + 1 < len(targets): - stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) - else: - stmt = '{} = __I'.format(targets_tmp[i]) - loop_body.append(self.parse_stmt(stmt)) - stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( - targets[i], targets_tmp[i], i) + ''' + t = ast.parse(template).body[0] + t.body[0].value = node.iter + t_loop = t.body[1] + t_loop.iter.args[0] = self.parse_expr('__ndrange.acc_dimensions[0]') + targets = self.get_targets(node) + targets_tmp = ['__' + name for name in targets] + loop_body = t_loop.body + for i in range(len(targets)): + if i + 1 < len(targets): + stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format( + targets_tmp[i], i + 1) + else: + stmt = '{} = __I'.format(targets_tmp[i]) + loop_body.append(self.parse_stmt(stmt)) + stmt = '{} = {} + __ndrange.bounds[{}][0]'.format( + targets[i], targets_tmp[i], i) + loop_body.append(self.parse_stmt(stmt)) + if i + 1 < len(targets): + stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( + targets_tmp[i], i + 1) loop_body.append(self.parse_stmt(stmt)) - if i + 1 < len(targets): - stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format( - targets_tmp[i], i + 1) - loop_body.append(self.parse_stmt(stmt)) - loop_body += node.body - - node = ast.copy_location(t, node) - return self.visit(node) # further translate as a range for - - def visit_grouped_ndrange_for(node): - # for I in ti.grouped(ti.ndrange(n, m)) - self.generic_visit(node, ['body']) - target = node.target.id - template = ''' + loop_body += node.body + + node = ast.copy_location(t, node) + return self.visit(node) # further translate as a range for + + def visit_grouped_ndrange_for(self, node): + # for I in ti.grouped(ti.ndrange(n, m)) + self.generic_visit(node, ['body']) + target = node.target.id + template = ''' if ti.static(1): __ndrange = 0 {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions))) @@ -402,60 +398,65 @@ def visit_grouped_ndrange_for(node): if __grouped_I + 1 < len(__ndrange.dimensions): __I = __I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1] ti.core.end_frontend_range_for() - '''.format(target, target) - t = ast.parse(template).body[0] - t.body[0].value = node.iter.args[0] - cut = len(t.body) - 1 - t.body = t.body[:cut] + node.body + t.body[cut:] - return ast.copy_location(t, node) - - def visit_struct_for(is_grouped): - # for i, j in x - # for I in ti.grouped(x) - self.generic_visit(node, ['body']) - targets = get_targets() - - for loop_var in targets: - self.check_loop_var(loop_var) - - var_decl = ''.join( - ' {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(name) - for name in targets) # indent: 4 spaces - vars = ', '.join(targets) - if is_grouped: - template = ''' + '''.format(target, target) + t = ast.parse(template).body[0] + t.body[0].value = node.iter.args[0] + cut = len(t.body) - 1 + t.body = t.body[:cut] + node.body + t.body[cut:] + return ast.copy_location(t, node) + + def visit_struct_for(self, node, is_grouped): + # for i, j in x + # for I in ti.grouped(x) + self.generic_visit(node, ['body']) + targets = self.get_targets(node) + + for loop_var in targets: + self.check_loop_var(loop_var) + + var_decl = ''.join( + ' {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(name) + for name in targets) # indent: 4 spaces + vars = ', '.join(targets) + if is_grouped: + template = ''' if 1: ___loop_var = 0 {} = ti.make_var_vector(size=___loop_var.loop_range().dim()) ___expr_group = ti.make_expr_group({}) ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr) ti.core.end_frontend_range_for() - '''.format(vars, vars) - t = ast.parse(template).body[0] - cut = 4 - t.body[0].value = node.iter - t.body = t.body[:cut] + node.body + t.body[cut:] - else: - template = ''' + '''.format(vars, vars) + t = ast.parse(template).body[0] + cut = 4 + t.body[0].value = node.iter + t.body = t.body[:cut] + node.body + t.body[cut:] + else: + template = ''' if 1: {} ___loop_var = 0 ___expr_group = ti.make_expr_group({}) ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr) ti.core.end_frontend_range_for() - '''.format(var_decl, vars) - t = ast.parse(template).body[0] - cut = len(targets) + 3 - t.body[cut - 3].value = node.iter - t.body = t.body[:cut] + node.body + t.body[cut:] - for loop_var in reversed(targets): - t.body.append(self.parse_stmt('del {}'.format(loop_var))) - return ast.copy_location(t, node) - - decorator = get_decorator(node.iter) + '''.format(var_decl, vars) + t = ast.parse(template).body[0] + cut = len(targets) + 3 + t.body[cut - 3].value = node.iter + t.body = t.body[:cut] + node.body + t.body[cut:] + for loop_var in reversed(targets): + t.body.append(self.parse_stmt('del {}'.format(loop_var))) + return ast.copy_location(t, node) + + def visit_For(self, node): + if node.orelse: + raise TaichiSyntaxError( + "'else' clause for 'for' not supported in Taichi kernels") + + decorator = self.get_decorator(node.iter) double_decorator = '' if decorator != '' and len(node.iter.args) == 1: - double_decorator = get_decorator(node.iter.args[0]) + double_decorator = self.get_decorator(node.iter.args[0]) ast.fix_missing_locations(node) if decorator == 'static': @@ -463,31 +464,31 @@ def visit_struct_for(is_grouped): raise TaichiSyntaxError("'ti.static' cannot be nested") if double_decorator == 'grouped' and ( len(node.iter.args[0].args) != 1 - or get_decorator(node.iter.args[0].args[0]) == ''): + or self.get_decorator(node.iter.args[0].args[0]) == ''): raise TaichiSyntaxError( "Static grouped struct for loop is not allowed. Please use 'ti.static(ti.grouped(ti.ndrange(...)))' instead." ) - return visit_static_for() + return self.visit_static_for(node) elif decorator == 'ndrange': if double_decorator != '': raise TaichiSyntaxError( "No decorator is allowed inside 'ti.ndrange") - return visit_ndrange_for(node) + return self.visit_ndrange_for(node) elif decorator == 'grouped': if double_decorator == 'static': raise TaichiSyntaxError( "'ti.static' is not allowed inside 'ti.grouped'") elif double_decorator == 'ndrange': - return visit_grouped_ndrange_for(node) + return self.visit_grouped_ndrange_for(node) elif double_decorator == 'grouped': raise TaichiSyntaxError("'ti.grouped' cannot be nested") else: - return visit_struct_for(is_grouped=True) + return self.visit_struct_for(node, is_grouped=True) elif isinstance(node.iter, ast.Call) and isinstance( node.iter.func, ast.Name) and node.iter.func.id == 'range': - return visit_range_for() + return self.visit_range_for(node) else: # Struct for - return visit_struct_for(is_grouped=False) + return self.visit_struct_for(node, is_grouped=False) @staticmethod def parse_stmt(stmt):