Skip to content

Commit

Permalink
[Lang] Support "break" and "continue" in static-fors (#1496)
Browse files Browse the repository at this point in the history
* mv ScopeGuard.{t,local_scope}

* control_scope (tmp)

* fix

* continue, test

* fix test

* fix while indent

* [skip ci] enforce code format

* fix test_materialization_after_kernel

* [skip ci] Update python/taichi/lang/transformer.py

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
archibate and taichi-gardener authored Jul 16, 2020
1 parent 6451486 commit 102d9b8
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 45 deletions.
114 changes: 69 additions & 45 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@


class ScopeGuard:
def __init__(self, t, stmt_block=None):
self.t = t
def __init__(self, scopes, stmt_block=None):
self.scopes = scopes
self.stmt_block = stmt_block

def __enter__(self):
self.t.local_scopes.append([])
self.scopes.append([])

def __exit__(self, exc_type, exc_val, exc_tb):
local = self.t.local_scopes[-1]
local = self.scopes[-1]
if self.stmt_block is not None:
for var in reversed(local):
stmt = ASTTransformer.parse_stmt('del var')
stmt.targets[0].id = var
self.stmt_block.append(stmt)
self.t.local_scopes = self.t.local_scopes[:-1]
self.scopes.pop()


# Single-pass transform
Expand All @@ -32,18 +32,27 @@ def __init__(self,
arg_features=None):
super().__init__()
self.local_scopes = []
self.control_scopes = []
self.excluded_parameters = excluded_paremeters
self.is_kernel = is_kernel
self.func = func
self.arg_features = arg_features
self.returns = None

# e.g.: FunctionDef, Module, Global
def variable_scope(self, *args):
return ScopeGuard(self, *args)
return ScopeGuard(self.local_scopes, *args)

# e.g.: For, While
def control_scope(self):
return ScopeGuard(self.control_scopes)

def current_scope(self):
return self.local_scopes[-1]

def current_control_scope(self):
return self.control_scopes[-1]

def var_declared(self, name):
for s in self.local_scopes:
if name in s:
Expand Down Expand Up @@ -207,7 +216,10 @@ def visit_While(self, node):
raise TaichiSyntaxError(
"'else' clause for 'while' not supported in Taichi kernels")

template = '''
with self.control_scope():
self.current_control_scope().append('while')

template = '''
if 1:
ti.core.begin_frontend_while(ti.Expr(1).ptr)
__while_cond = 0
Expand All @@ -217,13 +229,13 @@ def visit_While(self, node):
break
ti.core.pop_scope()
'''
cond = node.test
t = ast.parse(template).body[0]
t.body[1].value = cond
t.body = t.body[:3] + node.body + t.body[3:]
cond = node.test
t = ast.parse(template).body[0]
t.body[1].value = cond
t.body = t.body[:3] + node.body + t.body[3:]

self.generic_visit(t, ['body'])
return ast.copy_location(t, node)
self.generic_visit(t, ['body'])
return ast.copy_location(t, node)

def visit_block(self, list_stmt):
for i, l in enumerate(list_stmt):
Expand Down Expand Up @@ -290,6 +302,8 @@ def visit_static_for(self, node, is_grouped):
# for i in ti.static(range(n))
# for i, j in ti.static(ti.ndrange(n))
# for I in ti.static(ti.grouped(ti.ndrange(n, m)))

self.current_control_scope().append('static')
self.generic_visit(node, ['body'])
if is_grouped:
assert len(node.iter.args[0].args) == 1
Expand Down Expand Up @@ -462,36 +476,40 @@ def visit_For(self, node):
raise TaichiSyntaxError(
"'else' clause for 'for' not supported in Taichi kernels")

decorator = self.get_decorator(node.iter)
double_decorator = ''
if decorator != '' and len(node.iter.args) == 1:
double_decorator = self.get_decorator(node.iter.args[0])
ast.fix_missing_locations(node)

if decorator == 'static':
if double_decorator == 'static':
raise TaichiSyntaxError("'ti.static' cannot be nested")
return self.visit_static_for(node, double_decorator == 'grouped')
elif decorator == 'ndrange':
if double_decorator != '':
raise TaichiSyntaxError(
"No decorator is allowed inside 'ti.ndrange")
return self.visit_ndrange_for(node)
elif decorator == 'grouped':
if double_decorator == 'static':
raise TaichiSyntaxError(
"'ti.static' is not allowed inside 'ti.grouped'")
elif double_decorator == 'ndrange':
return self.visit_grouped_ndrange_for(node)
elif double_decorator == 'grouped':
raise TaichiSyntaxError("'ti.grouped' cannot be nested")
else:
return self.visit_struct_for(node, is_grouped=True)
elif isinstance(node.iter, ast.Call) and isinstance(
node.iter.func, ast.Name) and node.iter.func.id == 'range':
return self.visit_range_for(node)
else: # Struct for
return self.visit_struct_for(node, is_grouped=False)
with self.control_scope():
self.current_control_scope().append('for')

decorator = self.get_decorator(node.iter)
double_decorator = ''
if decorator != '' and len(node.iter.args) == 1:
double_decorator = self.get_decorator(node.iter.args[0])
ast.fix_missing_locations(node)

if decorator == 'static':
if double_decorator == 'static':
raise TaichiSyntaxError("'ti.static' cannot be nested")
return self.visit_static_for(node,
double_decorator == 'grouped')
elif decorator == 'ndrange':
if double_decorator != '':
raise TaichiSyntaxError(
"No decorator is allowed inside 'ti.ndrange")
return self.visit_ndrange_for(node)
elif decorator == 'grouped':
if double_decorator == 'static':
raise TaichiSyntaxError(
"'ti.static' is not allowed inside 'ti.grouped'")
elif double_decorator == 'ndrange':
return self.visit_grouped_ndrange_for(node)
elif double_decorator == 'grouped':
raise TaichiSyntaxError("'ti.grouped' cannot be nested")
else:
return self.visit_struct_for(node, is_grouped=True)
elif isinstance(node.iter, ast.Call) and isinstance(
node.iter.func, ast.Name) and node.iter.func.id == 'range':
return self.visit_range_for(node)
else: # Struct for
return self.visit_struct_for(node, is_grouped=False)

@staticmethod
def parse_stmt(stmt):
Expand Down Expand Up @@ -531,10 +549,16 @@ def visit_IfExp(self, node):
return ast.copy_location(call, node)

def visit_Break(self, node):
return self.parse_stmt('ti.core.insert_break_stmt()')
if 'static' in self.current_control_scope():
return node
else:
return self.parse_stmt('ti.core.insert_break_stmt()')

def visit_Continue(self, node):
return self.parse_stmt('ti.core.insert_continue_stmt()')
if 'static' in self.current_control_scope():
return node
else:
return self.parse_stmt('ti.core.insert_continue_stmt()')

def visit_Call(self, node):
if not (isinstance(node.func, ast.Attribute)
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_static.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import taichi as ti
import numpy as np


@ti.all_archs
Expand Down Expand Up @@ -51,3 +52,35 @@ def fill():
for i in range(3):
for j in range(3):
assert x[i, j][i, j] == i + j * 2


@ti.host_arch_only
def test_static_break():
x = ti.var(ti.i32, 5)

@ti.kernel
def func():
for i in ti.static(range(5)):
x[i] = 1
if ti.static(i == 2):
break

func()

assert np.allclose(x.to_numpy(), np.array([1, 1, 1, 0, 0]))


@ti.host_arch_only
def test_static_continue():
x = ti.var(ti.i32, 5)

@ti.kernel
def func():
for i in ti.static(range(5)):
if ti.static(i == 2):
continue
x[i] = 1

func()

assert np.allclose(x.to_numpy(), np.array([1, 1, 0, 1, 1]))

0 comments on commit 102d9b8

Please sign in to comment.