Skip to content

Commit

Permalink
Fixed more review comments
Browse files Browse the repository at this point in the history
Change-Id: If6133dd822f33a8d32f3ddd8b8ce22b92490694e
  • Loading branch information
mbaret committed Jun 9, 2021
1 parent 53b99b1 commit 2b61414
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/tvm/tir/transform/inject_rolling_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def InjectRollingBuffer():
buffer_to_attrs = defaultdict(list)
rolling_buffers = set()
rolling_buffer_to_info = dict()
iter_vars = list()
for_loops = list()
hoist_buffer_to_for = defaultdict(list)

RollingBufferInfo = namedtuple(
Expand All @@ -56,7 +56,7 @@ def InjectRollingBuffer():
def _pre_visit(stmt):
if isinstance(stmt, tvm.tir.For):
# Manage the stack of iter_vars
iter_vars.append(stmt)
for_loops.append(stmt)

elif isinstance(stmt, tvm.tir.AttrStmt):
if isinstance(stmt.node, tvm.tir.Buffer):
Expand Down Expand Up @@ -118,7 +118,7 @@ def _pre_visit(stmt):
# to be the rolling axis
roll_iter_var = None
roll_axis = -1
for loop in iter_vars:
for loop in for_loops:
iter_var = loop.loop_var
if iter_var in bound_iter_vars:
roll_iter_var = iter_var
Expand Down Expand Up @@ -148,7 +148,7 @@ def _pre_visit(stmt):
def _post_visit(stmt):
if isinstance(stmt, tvm.tir.For):
# Manage the stack of iter_vars
iter_vars.pop()
for_loops.pop()
# If the loop corresponds to an iter_var that needs a BufferRealize
# hoisting to its scope, perform the hoisting
if stmt.loop_var in hoist_buffer_to_for:
Expand Down
3 changes: 1 addition & 2 deletions src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,8 @@ ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Sta
}
ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
debug_keep_trivial_loop);
bool skip_ivar_domain = !stage->rolling_buffer;
ret.init_predicates =
MakeBoundCheck(stage, dom_map, ret.init_vmap, skip_ivar_domain, skip_iter);
MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
Expand Down

0 comments on commit 2b61414

Please sign in to comment.