NOTE(review): this patch was rebuilt from a copy whose line breaks and
indentation had been collapsed. Hunk line counts were re-checked against the
"@@" headers, but leading whitespace (notably the whitespace-only CMake change
and some context lines) is a best-effort reconstruction - confirm against the
original commit before applying.
Two added lines were corrected relative to the collapsed patch:
  * examples/mgpcg_advanced.py, init(): the sine term now reads the grid
    coordinate I[i] instead of the static axis index i, which made the
    right-hand side constant over the whole grid.
  * python/taichi/lang/transformer.py, visit_For(): the "No decorator is
    allowed inside 'ti.ndrange'" message gained its missing closing quote.
The post-image blob ids on the "index" lines were not recomputed.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 095a7919af566..e8b99e29e35b3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,12 +11,12 @@ SET(TI_VERSION_MINOR 5)
 SET(TI_VERSION_PATCH 8)
 
 execute_process(
-    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+        WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
         COMMAND git rev-parse --short HEAD
         RESULT_VARIABLE SHORT_HASH_RESULT
         OUTPUT_VARIABLE TI_COMMIT_SHORT_HASH)
 execute_process(
-    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+        WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
         COMMAND git rev-parse HEAD
         RESULT_VARIABLE SHORT_HASH_RESULT
         OUTPUT_VARIABLE TI_COMMIT_HASH)
@@ -71,4 +71,3 @@ FILE(WRITE ${CMAKE_CURRENT_LIST_DIR}/taichi/common/version.h
         "#define TI_CUDAVERSION \"${CUDA_VERSION}\"\n"
         "#define TI_CUDAROOT_DIR \"${CUDA_TOOLKIT_ROOT_DIR}\"\n"
         )
-
diff --git a/examples/mgpcg_advanced.py b/examples/mgpcg_advanced.py
index 636c7a3924d6f..bf714857a0b07 100644
--- a/examples/mgpcg_advanced.py
+++ b/examples/mgpcg_advanced.py
@@ -35,26 +35,28 @@ def __init__(self):
         self.pixels = ti.var(dt=real,
                              shape=(self.N_gui, self.N_gui))  # image buffer
 
-        self.grid = ti.root.pointer(ti.ijk, [self.N_tot // 4]).dense(
-            ti.ijk, 4).place(self.x, self.p, self.Ap)
+        indices = ti.ijk if self.dim == 3 else ti.ij
+        self.grid = ti.root.pointer(indices,
+                                    [self.N_tot // 4]).dense(indices, 4).place(
+                                        self.x, self.p, self.Ap)
 
         for l in range(self.n_mg_levels):
-            self.grid = ti.root.pointer(ti.ijk,
+            self.grid = ti.root.pointer(indices,
                                         [self.N_tot // (4 * 2**l)]).dense(
-                                            ti.ijk,
-                                            4).place(self.r[l], self.z[l])
+                                            indices, 4).place(
+                                                self.r[l], self.z[l])
 
         ti.root.place(self.alpha, self.beta, self.sum)
 
     @ti.kernel
     def init(self):
-        for I in ti.grouped(ti.ndrange((self.N_ext, self.N_tot - self.N_ext),
-                                       (self.N_ext, self.N_tot - self.N_ext),
-                                       (self.N_ext, self.N_tot - self.N_ext))):
+        for I in ti.grouped(
+                ti.ndrange(*(
+                    (self.N_ext, self.N_tot - self.N_ext), ) * self.dim)):
             self.r[0][I] = 1.0
             for i in ti.static(range(self.dim)):
-                self.r[0][I] *= ti.sin(2.0 * np.pi *
-                                       (i - self.N_ext) * 2.0 / self.N_tot)
+                self.r[0][I] *= ti.sin(2.0 * np.pi * (I[i] - self.N_ext) * 2.0 /
+                                       self.N_tot)
             self.z[0][I] = 0.0
             self.Ap[I] = 0.0
             self.p[I] = 0.0
@@ -98,7 +100,7 @@ def update_p(self):
     @ti.kernel
     def restrict(self, l: ti.template()):
         for I in ti.grouped(self.r[l]):
-            res = self.r[l][I] - (6.0 * self.z[l][I] -
+            res = self.r[l][I] - (2.0 * self.dim * self.z[l][I] -
                                   self.neighbor_sum(self.z[l], I))
             self.r[l + 1][I // 2] += res * 0.5
 
@@ -112,8 +114,8 @@ def smooth(self, l: ti.template(), phase: ti.template()):
         # phase = red/black Gauss-Seidel phase
         for I in ti.grouped(self.r[l]):
             if (I.sum()) & 1 == phase:
-                self.z[l][I] = (self.r[l][I] +
-                                self.neighbor_sum(self.z[l], I)) / 6.0
+                self.z[l][I] = (self.r[l][I] + self.neighbor_sum(
+                    self.z[l], I)) / (2.0 * self.dim)
 
     def apply_preconditioner(self):
         self.z[0].fill(0)
@@ -137,11 +139,12 @@ def apply_preconditioner(self):
 
     @ti.kernel
     def paint(self):
-        kk = self.N_tot * 3 // 8
-        for i, j in self.pixels:
-            ii = int(i * self.N / self.N_gui) + self.N_ext
-            jj = int(j * self.N / self.N_gui) + self.N_ext
-            self.pixels[i, j] = self.x[ii, jj, kk] / self.N_tot
+        if ti.static(self.dim == 3):
+            kk = self.N_tot * 3 // 8
+            for i, j in self.pixels:
+                ii = int(i * self.N / self.N_gui) + self.N_ext
+                jj = int(j * self.N / self.N_gui) + self.N_ext
+                self.pixels[i, j] = self.x[ii, jj, kk] / self.N_tot
 
     def run(self):
         gui = ti.GUI("Multigrid Preconditioned Conjugate Gradients",
diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py
index 7a6c0bf8efdb8..3a4b6831a3dea 100644
--- a/python/taichi/lang/transformer.py
+++ b/python/taichi/lang/transformer.py
@@ -274,131 +274,46 @@ def check_loop_var(self, loop_var):
                 "Variable '{}' is already declared in the outer scope and cannot be used as loop variable"
                 .format(loop_var))
 
-    def visit_For(self, node):
-        if node.orelse:
-            raise TaichiSyntaxError(
-                "'else' clause for 'for' not supported in Taichi kernels")
-
-        def is_decorated(iter):
-            return isinstance(iter, ast.Call) \
-                and isinstance(iter.func, ast.Attribute) \
-                and isinstance(iter.func.value, ast.Name) \
-                and iter.func.value.id == 'ti' and (
-                    iter.func.attr == 'static'
-                    or iter.func.attr == 'grouped'
-                    or iter.func.attr == 'ndrange')
-
-        decorated = is_decorated(node.iter)
-        double_decorated = decorated and len(node.iter.args) == 1 \
-            and is_decorated(node.iter.args[0])
-        is_ndrange_for = 0
-        is_static_for = 0
-        is_grouped = 0
+    @staticmethod
+    def get_decorator(iter):
+        if not (isinstance(iter, ast.Call)
+                and isinstance(iter.func, ast.Attribute)
+                and isinstance(iter.func.value, ast.Name)
+                and iter.func.value.id == 'ti' and
+                (iter.func.attr == 'static' or iter.func.attr == 'grouped'
+                 or iter.func.attr == 'ndrange')):
+            return ''
+        return iter.func.attr
 
-        if decorated:
-            attr = node.iter.func
-            # outer decorator
-            if attr.attr == 'static':
-                is_static_for = 1
-            elif attr.attr == 'grouped':
-                is_grouped = 1
-            elif attr.attr == 'ndrange':
-                is_ndrange_for = 1
-            else:
-                raise Exception('Not supported')
-        if double_decorated:
-            attr = node.iter.args[0].func
-            # inner decorator
-            if attr.attr == 'static':
-                if is_static_for == 1:
-                    raise TaichiSyntaxError("'ti.static' cannot be nested")
-                is_static_for = 2
-            elif attr.attr == 'grouped':
-                if is_grouped == 1:
-                    raise TaichiSyntaxError(
-                        "'ti.grouped' cannot be nested")
-                is_grouped = 2
-            elif attr.attr == 'ndrange':
-                if is_ndrange_for == 1:
-                    raise TaichiSyntaxError(
-                        "'ti.ndrange' cannot be nested")
-                is_ndrange_for = 2
-            else:
-                raise Exception('Not supported')
+    @staticmethod
+    def get_targets(node):
+        if isinstance(node.target, ast.Name):
+            return [node.target.id]
+        else:
+            assert isinstance(node.target, ast.Tuple)
+            return [name.id for name in node.target.elts]
+
+    def visit_static_for(self, node):
+        # for i in ti.static(range(n))
+        # for i, j in ti.static(ti.ndrange(n))
+        # for I in ti.static(ti.grouped(ti.ndrange(n, m)))
+        self.generic_visit(node, ['body'])
+        t = self.parse_stmt('if 1: pass; del a')
+        t.body[0] = node
+        target = copy.deepcopy(node.target)
+        target.ctx = ast.Del()
+        if isinstance(target, ast.Tuple):
+            for tar in target.elts:
+                tar.ctx = ast.Del()
+        t.body[1].targets = [target]
+        return t
 
-        is_range_for = isinstance(node.iter, ast.Call) and isinstance(
-            node.iter.func, ast.Name) and node.iter.func.id == 'range'
-        ast.fix_missing_locations(node)
-        is_nonstatic_ndrange = is_ndrange_for == 1 or (is_ndrange_for == 2
-                                                       and is_grouped == 1)
-        if not is_nonstatic_ndrange:
-            self.generic_visit(node, ['body'])
-        if is_nonstatic_ndrange:
-            dim = len(node.iter.args) if is_ndrange_for == 1 else len(
-                node.iter.args[0].args)
-            template = '''
-if ti.static(1):
-    __ndrange = 0
-    for __ndrange_I in range(0):
-        __I = __ndrange_I
-            ''' if is_ndrange_for == 1 else '''
-if ti.static(1):
-    __ndrange = 0
-    {} = ti.Vector([0] * len(__ndrange.dimensions))
-    for __ndrange_I in range(0):
-        __I = __ndrange_I
-            '''.format(node.target.id)
-            t = ast.parse(template).body[0]
-            t.body[0].value = node.iter if is_ndrange_for == 1 \
-                else node.iter.args[0]
-            t_loop = t.body[1] if is_ndrange_for == 1 else t.body[2]
-            t_loop.iter.args[0] = self.parse_expr(
-                '__ndrange.acc_dimensions[0]')
-            targets = node.target
-            if is_ndrange_for == 1:
-                if isinstance(targets, ast.Tuple):
-                    targets = [name.id for name in targets.elts]
-                else:
-                    targets = [targets.id]
-                targets_tmp = ['__' + name for name in targets]
-            else:
-                targets = ['{}[{}]'.format(targets.id, i) for i in range(dim)]
-                targets_tmp = [
-                    '__{}_{}'.format(node.target.id, i) for i in range(dim)
-                ]
-            loop_body = t_loop.body
-            for i in range(len(targets)):
-                if i + 1 < len(targets):
-                    stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format(
-                        targets_tmp[i], i + 1)
-                else:
-                    stmt = '{} = __I'.format(targets_tmp[i])
-                loop_body.append(self.parse_stmt(stmt))
-                stmt = '{} = {} + __ndrange.bounds[{}][0]'.format(
-                    targets[i], targets_tmp[i], i)
-                loop_body.append(self.parse_stmt(stmt))
-                if i + 1 < len(targets):
-                    stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format(
-                        targets_tmp[i], i + 1)
-                    loop_body.append(self.parse_stmt(stmt))
-            loop_body += node.body
-
-            node = ast.copy_location(t, node)
-            return self.visit(node)  # further translate as a range for
-        elif is_static_for == 1:
-            t = self.parse_stmt('if 1: pass; del a')
-            t.body[0] = node
-            target = copy.deepcopy(node.target)
-            target.ctx = ast.Del()
-            if isinstance(target, ast.Tuple):
-                for tar in target.elts:
-                    tar.ctx = ast.Del()
-            t.body[1].targets = [target]
-            return t
-        elif is_range_for == 1:
-            loop_var = node.target.id
-            self.check_loop_var(loop_var)
-            template = '''
+    def visit_range_for(self, node):
+        # for i in range(n)
+        self.generic_visit(node, ['body'])
+        loop_var = node.target.id
+        self.check_loop_var(loop_var)
+        template = '''
 if 1:
     {} = ti.Expr(ti.core.make_id_expr(''))
     ___begin = ti.Expr(0)
@@ -407,66 +322,173 @@ def is_decorated(iter):
     ___end = ti.cast(___end, ti.i32)
     ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
     ti.core.end_frontend_range_for()
-            '''.format(loop_var, loop_var)
-            t = ast.parse(template).body[0]
+        '''.format(loop_var, loop_var)
+        t = ast.parse(template).body[0]
 
-            assert len(node.iter.args) in [1, 2]
-            if len(node.iter.args) == 2:
-                bgn = node.iter.args[0]
-                end = node.iter.args[1]
-            else:
-                bgn = self.make_constant(value=0)
-                end = node.iter.args[0]
+        assert len(node.iter.args) in [1, 2]
+        if len(node.iter.args) == 2:
+            bgn = node.iter.args[0]
+            end = node.iter.args[1]
+        else:
+            bgn = self.make_constant(value=0)
+            end = node.iter.args[0]
 
-            t.body[1].value.args[0] = bgn
-            t.body[2].value.args[0] = end
-            t.body = t.body[:6] + node.body + t.body[6:]
-            t.body.append(self.parse_stmt('del {}'.format(loop_var)))
-            return ast.copy_location(t, node)
-        else:  # Struct for
-            assert is_static_for == 0
-            assert is_ndrange_for == 0
-            if isinstance(node.target, ast.Name):
-                elts = [node.target]
+        t.body[1].value.args[0] = bgn
+        t.body[2].value.args[0] = end
+        t.body = t.body[:6] + node.body + t.body[6:]
+        t.body.append(self.parse_stmt('del {}'.format(loop_var)))
+        return ast.copy_location(t, node)
+
+    def visit_ndrange_for(self, node):
+        # for i, j in ti.ndrange(n)
+        template = '''
+if ti.static(1):
+    __ndrange = 0
+    for __ndrange_I in range(0):
+        __I = __ndrange_I
+        '''
+        t = ast.parse(template).body[0]
+        t.body[0].value = node.iter
+        t_loop = t.body[1]
+        t_loop.iter.args[0] = self.parse_expr('__ndrange.acc_dimensions[0]')
+        targets = self.get_targets(node)
+        targets_tmp = ['__' + name for name in targets]
+        loop_body = t_loop.body
+        for i in range(len(targets)):
+            if i + 1 < len(targets):
+                stmt = '{} = __I // __ndrange.acc_dimensions[{}]'.format(
+                    targets_tmp[i], i + 1)
             else:
-                elts = node.target.elts
+                stmt = '{} = __I'.format(targets_tmp[i])
+            loop_body.append(self.parse_stmt(stmt))
+            stmt = '{} = {} + __ndrange.bounds[{}][0]'.format(
+                targets[i], targets_tmp[i], i)
+            loop_body.append(self.parse_stmt(stmt))
+            if i + 1 < len(targets):
+                stmt = '__I = __I - {} * __ndrange.acc_dimensions[{}]'.format(
+                    targets_tmp[i], i + 1)
+                loop_body.append(self.parse_stmt(stmt))
+        loop_body += node.body
+
+        node = ast.copy_location(t, node)
+        return self.visit(node)  # further translate as a range for
+
+    def visit_grouped_ndrange_for(self, node):
+        # for I in ti.grouped(ti.ndrange(n, m))
+        self.generic_visit(node, ['body'])
+        target = node.target.id
+        template = '''
+if ti.static(1):
+    __ndrange = 0
+    {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions)))
+    ___begin = ti.Expr(0)
+    ___end = __ndrange.acc_dimensions[0]
+    ___begin = ti.cast(___begin, ti.i32)
+    ___end = ti.cast(___end, ti.i32)
+    __ndrange_I = ti.Expr(ti.core.make_id_expr(''))
+    ti.core.begin_frontend_range_for(__ndrange_I.ptr, ___begin.ptr, ___end.ptr)
+    __I = __ndrange_I
+    for __grouped_I in range(len(__ndrange.dimensions)):
+        __grouped_I_tmp = 0
+        if __grouped_I + 1 < len(__ndrange.dimensions):
+            __grouped_I_tmp = __I // __ndrange.acc_dimensions[__grouped_I + 1]
+        else:
+            __grouped_I_tmp = __I
+        ti.subscript({}, __grouped_I).assign(__grouped_I_tmp + __ndrange.bounds[__grouped_I][0])
+        if __grouped_I + 1 < len(__ndrange.dimensions):
+            __I = __I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]
+    ti.core.end_frontend_range_for()
+        '''.format(target, target)
+        t = ast.parse(template).body[0]
+        t.body[0].value = node.iter.args[0]
+        cut = len(t.body) - 1
+        t.body = t.body[:cut] + node.body + t.body[cut:]
+        return ast.copy_location(t, node)
+
+    def visit_struct_for(self, node, is_grouped):
+        # for i, j in x
+        # for I in ti.grouped(x)
+        self.generic_visit(node, ['body'])
+        targets = self.get_targets(node)
 
-            for loop_var in elts:
-                self.check_loop_var(loop_var.id)
+        for loop_var in targets:
+            self.check_loop_var(loop_var)
 
-            var_decl = ''.join(
-                '    {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(ind.id)
-                for ind in elts)  # indent: 4 spaces
-            vars = ', '.join(ind.id for ind in elts)
-            if is_grouped:
-                template = '''
+        var_decl = ''.join(
+            '    {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(name)
+            for name in targets)  # indent: 4 spaces
+        vars = ', '.join(targets)
+        if is_grouped:
+            template = '''
 if 1:
     ___loop_var = 0
     {} = ti.make_var_vector(size=___loop_var.loop_range().dim())
     ___expr_group = ti.make_expr_group({})
     ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
     ti.core.end_frontend_range_for()
-                '''.format(vars, vars)
-                t = ast.parse(template).body[0]
-                cut = 4
-                t.body[0].value = node.iter
-                t.body = t.body[:cut] + node.body + t.body[cut:]
-            else:
-                template = '''
+            '''.format(vars, vars)
+            t = ast.parse(template).body[0]
+            cut = 4
+            t.body[0].value = node.iter
+            t.body = t.body[:cut] + node.body + t.body[cut:]
+        else:
+            template = '''
 if 1:
 {}
     ___loop_var = 0
     ___expr_group = ti.make_expr_group({})
     ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
     ti.core.end_frontend_range_for()
-                '''.format(var_decl, vars)
-                t = ast.parse(template).body[0]
-                cut = len(elts) + 3
-                t.body[cut - 3].value = node.iter
-                t.body = t.body[:cut] + node.body + t.body[cut:]
-                for loop_var in reversed(elts):
-                    t.body.append(self.parse_stmt('del {}'.format(loop_var.id)))
-            return ast.copy_location(t, node)
+            '''.format(var_decl, vars)
+            t = ast.parse(template).body[0]
+            cut = len(targets) + 3
+            t.body[cut - 3].value = node.iter
+            t.body = t.body[:cut] + node.body + t.body[cut:]
+            for loop_var in reversed(targets):
+                t.body.append(self.parse_stmt('del {}'.format(loop_var)))
+        return ast.copy_location(t, node)
+
+    def visit_For(self, node):
+        if node.orelse:
+            raise TaichiSyntaxError(
+                "'else' clause for 'for' not supported in Taichi kernels")
+
+        decorator = self.get_decorator(node.iter)
+        double_decorator = ''
+        if decorator != '' and len(node.iter.args) == 1:
+            double_decorator = self.get_decorator(node.iter.args[0])
+        ast.fix_missing_locations(node)
+
+        if decorator == 'static':
+            if double_decorator == 'static':
+                raise TaichiSyntaxError("'ti.static' cannot be nested")
+            if double_decorator == 'grouped' and (
+                    len(node.iter.args[0].args) != 1
+                    or self.get_decorator(node.iter.args[0].args[0]) == ''):
+                raise TaichiSyntaxError(
+                    "Static grouped struct for loop is not allowed. Please use 'ti.static(ti.grouped(ti.ndrange(...)))' instead."
+                )
+            return self.visit_static_for(node)
+        elif decorator == 'ndrange':
+            if double_decorator != '':
+                raise TaichiSyntaxError(
+                    "No decorator is allowed inside 'ti.ndrange'")
+            return self.visit_ndrange_for(node)
+        elif decorator == 'grouped':
+            if double_decorator == 'static':
+                raise TaichiSyntaxError(
+                    "'ti.static' is not allowed inside 'ti.grouped'")
+            elif double_decorator == 'ndrange':
+                return self.visit_grouped_ndrange_for(node)
+            elif double_decorator == 'grouped':
+                raise TaichiSyntaxError("'ti.grouped' cannot be nested")
+            else:
+                return self.visit_struct_for(node, is_grouped=True)
+        elif isinstance(node.iter, ast.Call) and isinstance(
+                node.iter.func, ast.Name) and node.iter.func.id == 'range':
+            return self.visit_range_for(node)
+        else:  # Struct for
+            return self.visit_struct_for(node, is_grouped=False)
 
     @staticmethod
     def parse_stmt(stmt):
diff --git a/tests/python/test_grouped.py b/tests/python/test_grouped.py
index 26339ff69888f..c18d1cde04e7c 100644
--- a/tests/python/test_grouped.py
+++ b/tests/python/test_grouped.py
@@ -81,6 +81,58 @@ def test():
                                  j * 2 if x0 <= i < y0 and x1 <= j < y1 else 0)
 
 
+@ti.all_archs
+def test_static_grouped_ndrange():
+    val = ti.var(ti.i32)
+
+    n = 4
+    m = 8
+
+    ti.root.dense(ti.ij, (n, m)).place(val)
+
+    x0 = 2
+    y0 = 3
+    x1 = 1
+    y1 = 6
+
+    @ti.kernel
+    def test():
+        for I in ti.static(ti.grouped(ti.ndrange((x0, y0), (x1, y1)))):
+            val[I] = I[0] + I[1] * 2
+
+    test()
+
+    for i in range(n):
+        for j in range(m):
+            assert val[i, j] == (i +
+                                 j * 2 if x0 <= i < y0 and x1 <= j < y1 else 0)
+
+
+@ti.all_archs
+def test_grouped_ndrange_starred():
+    val = ti.var(ti.i32)
+
+    n = 4
+    m = 8
+    p = 16
+    dim = 3
+
+    ti.root.dense(ti.ijk, (n, m, p)).place(val)
+
+    @ti.kernel
+    def test():
+        for I in ti.grouped(ti.ndrange(*(((0, n), ) * dim))):
+            val[I] = I[0] + I[1] * 2 + I[2] * 3
+
+    test()
+
+    for i in range(n):
+        for j in range(m):
+            for k in range(p):
+                assert val[i, j,
+                           k] == (i + j * 2 + k * 3 if j < n and k < n else 0)
+
+
 @ti.all_archs
 def test_grouped_ndrange_0d():
     val = ti.var(ti.i32, shape=())
@@ -96,7 +148,7 @@ def test():
 
 
 @ti.all_archs
-def test_grouped_ndrange_0d():
+def test_static_grouped_ndrange_0d():
     val = ti.var(ti.i32, shape=())
 
     @ti.kernel
diff --git a/tests/python/test_syntax_errors.py b/tests/python/test_syntax_errors.py
index 92fa0ee1bb5ad..66bae69375a2e 100644
--- a/tests/python/test_syntax_errors.py
+++ b/tests/python/test_syntax_errors.py
@@ -210,3 +210,17 @@ def func():
             pass
 
     func()
+
+
+@ti.must_throw(ti.TaichiSyntaxError)
+def test_static_grouped_struct_for():
+    val = ti.var(ti.i32)
+
+    ti.root.dense(ti.ij, (1, 1)).place(val)
+
+    @ti.kernel
+    def test():
+        for I in ti.static(ti.grouped(val)):
+            pass
+
+    test()