Skip to content

Commit

Permalink
[async] Cache demote_activation (#1889)
Browse files Browse the repository at this point in the history
* [async] Cache demote_activation

* [skip ci] Apply suggestions from code review

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>

* [skip ci] enforce code format

* [skip ci] Edit comment

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
3 people authored Sep 25, 2020
1 parent e8a5686 commit 25b9a10
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 83 deletions.
90 changes: 90 additions & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,96 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
return result;
}

// TODO: make this an IR pass
class ConstExprPropagation {
public:
static std::unordered_set<Stmt *> run(
Block *block,
const std::function<bool(Stmt *)> &is_const_seed) {
std::unordered_set<Stmt *> const_stmts;

auto is_const = [&](Stmt *stmt) {
if (is_const_seed(stmt)) {
return true;
} else {
return const_stmts.find(stmt) != const_stmts.end();
}
};

for (auto &s : block->statements) {
if (is_const(s.get())) {
const_stmts.insert(s.get());
} else if (auto binary = s->cast<BinaryOpStmt>()) {
if (is_const(binary->lhs) && is_const(binary->rhs)) {
const_stmts.insert(s.get());
}
} else if (auto unary = s->cast<UnaryOpStmt>()) {
if (is_const(unary->operand)) {
const_stmts.insert(s.get());
}
} else {
// TODO: ...
}
}

return const_stmts;
}
};

// Returns a handle to a variant of |handle|'s IR in which `activate` has been
// cleared on every top-level GlobalPtrStmt whose indices are all constant
// expressions (seeded by ConstStmt / LoopIndexStmt and propagated through
// unary and binary ops). The outcome is memoized in demote_activation_bank_.
// When no statement can be demoted, |handle| itself is cached and returned.
IRHandle IRBank::demote_activation(IRHandle handle) {
  // A slot for which empty() holds is treated as a cache miss and is filled
  // in before returning.
  auto &cached = demote_activation_bank_[handle];
  if (!cached.empty()) {
    return cached;
  }

  // Mutate a clone rather than the IR referenced by |handle|.
  std::unique_ptr<IRNode> cloned_ir = handle.clone();
  OffloadedStmt *task = cloned_ir->as<OffloadedStmt>();
  Block *task_body = task->body.get();

  TI_ASSERT(task->snode != nullptr);

  // TODO: for now we only deal with the top level. Is there an easy way to
  // extend this part?
  const auto const_exprs =
      ConstExprPropagation::run(task_body, [](Stmt *stmt) {
        return stmt->is<ConstStmt>() || stmt->is<LoopIndexStmt>();
      });

  bool any_demoted = false;
  for (auto &owned : task_body->statements) {
    auto ptr = owned->cast<GlobalPtrStmt>();
    if (!ptr || !ptr->activate) {
      continue;
    }

    // TODO: test input mask?
    bool all_indices_const = true;
    for (auto index : ptr->indices) {
      if (const_exprs.count(index) == 0) {
        // One non-constant index is enough to rule out demotion.
        all_indices_const = false;
        break;
      }
    }

    if (all_indices_const) {
      ptr->activate = false;
      any_demoted = true;
    }
  }

  if (any_demoted) {
    // Register the modified clone in the bank and cache its handle.
    cached = IRHandle(cloned_ir.get(), get_hash(cloned_ir.get()));
    insert(std::move(cloned_ir), cached.hash());
    return cached;
  }

  // Nothing demoted: cache the input handle; the clone is destroyed when this
  // function returns.
  cached = handle;
  return cached;
}

ParallelExecutor::ParallelExecutor(int num_threads)
: num_threads(num_threads),
status(ExecutorStatus::uninitialized),
Expand Down
3 changes: 3 additions & 0 deletions taichi/program/async_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class IRBank {
// Fuse handle_b into handle_a
IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel);

IRHandle demote_activation(IRHandle handle);

std::unordered_map<IRHandle, TaskMeta> meta_bank_;
std::unordered_map<IRHandle, TaskFusionMeta> fusion_meta_bank_;

Expand All @@ -39,6 +41,7 @@ class IRBank {
std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
std::vector<std::unique_ptr<IRNode>> trash_bin; // prevent IR from deleted
std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
std::unordered_map<IRHandle, IRHandle> demote_activation_bank_;
};

class ParallelExecutor {
Expand Down
4 changes: 4 additions & 0 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class IRHandle {
return hash_ == other_ir_handle.hash_;
}

// Inequality, defined as the negation of operator== on the same operands.
bool operator!=(const IRHandle &other_ir_handle) const {
return !(*this == other_ir_handle);
}

// Orders handles by comparing their hash_ members.
bool operator<(const IRHandle &other_ir_handle) const {
return hash_ < other_ir_handle.hash_;
}
Expand Down
94 changes: 11 additions & 83 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,42 +935,6 @@ void StateFlowGraph::verify() {
topo_sort_nodes();
}

// TODO: make this an IR pass
class ConstExprPropagation {
public:
static std::unordered_set<Stmt *> run(
Block *block,
const std::function<bool(Stmt *)> &is_const_seed) {
std::unordered_set<Stmt *> const_stmts;

auto is_const = [&](Stmt *stmt) {
if (is_const_seed(stmt)) {
return true;
} else {
return const_stmts.find(stmt) != const_stmts.end();
}
};

for (auto &s : block->statements) {
if (is_const(s.get())) {
const_stmts.insert(s.get());
} else if (auto binary = s->cast<BinaryOpStmt>()) {
if (is_const(binary->lhs) && is_const(binary->rhs)) {
const_stmts.insert(s.get());
}
} else if (auto unary = s->cast<UnaryOpStmt>()) {
if (is_const(unary->operand)) {
const_stmts.insert(s.get());
}
} else {
// TODO: ...
}
}

return const_stmts;
}
};

bool StateFlowGraph::demote_activation() {
bool modified = false;

Expand Down Expand Up @@ -1001,55 +965,19 @@ bool StateFlowGraph::demote_activation() {
if (nodes.size() <= 1)
continue;

auto snode = nodes[0]->meta->snode;

auto list_state = AsyncState(snode, AsyncState::Type::list);

TI_ASSERT(snode != nullptr);

std::unique_ptr<IRNode> new_ir = nodes[0]->rec.ir_handle.clone();

OffloadedStmt *offload = new_ir->as<OffloadedStmt>();
Block *body = offload->body.get();

// TODO: for now we only deal with the top level. Is there an easy way to
// extend this part?
auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) {
if (stmt->is<ConstStmt>()) {
return true;
} else if (stmt->is<LoopIndexStmt>())
return true;
return false;
});

bool demoted = false;
for (int k = 0; k < (int)body->statements.size(); k++) {
Stmt *stmt = body->statements[k].get();
if (auto ptr = stmt->cast<GlobalPtrStmt>(); ptr && ptr->activate) {
bool can_demote = true;
// TODO: test input mask?
for (auto ind : ptr->indices) {
if (consts.find(ind) == consts.end()) {
// non-constant index
can_demote = false;
}
}
if (can_demote) {
modified = true;
ptr->activate = false;
demoted = true;
}
}
}
// TODO: cache this part
auto new_handle = IRHandle(new_ir.get(), ir_bank_->get_hash(new_ir.get()));
ir_bank_->insert(std::move(new_ir), new_handle.hash());
auto new_meta = get_task_meta(ir_bank_, nodes[0]->rec);
if (demoted) {
for (int j = 1; j < (int)nodes.size(); j++) {
auto new_handle = ir_bank_->demote_activation(nodes[0]->rec.ir_handle);
if (new_handle != nodes[0]->rec.ir_handle) {
modified = true;
nodes[1]->rec.ir_handle = new_handle;
nodes[1]->meta = get_task_meta(ir_bank_, nodes[1]->rec);
for (int j = 2; j < (int)nodes.size(); j++) {
nodes[j]->rec.ir_handle = new_handle;
nodes[j]->meta = new_meta;
nodes[j]->meta = nodes[1]->meta;
}
// For every "demote_activation" call, we only optimize for a single key
// in std::map<std::pair<IRHandle, Node *>, std::vector<Node *>> tasks
// since the graph probably needs to be rebuilt after demoting
// part of the tasks.
break;
}
}
Expand Down

0 comments on commit 25b9a10

Please sign in to comment.