[async] Avoid unnecessary list generations and activations (#913)
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
yuanming-hu and taichi-gardener authored May 3, 2020
1 parent 6b60c13 commit 92b6dfc
Showing 6 changed files with 218 additions and 5 deletions.
4 changes: 4 additions & 0 deletions examples/mgpcg_advanced.py
@@ -1,4 +1,5 @@
import numpy as np
import time
import taichi as ti

real = ti.f32
@@ -210,5 +211,8 @@ def run(self):


solver = MGPCG()
t = time.time()
solver.run()
print(f'Solver time: {time.time() - t:.3f} s')
ti.core.print_profile_info()
ti.core.print_stat()
17 changes: 17 additions & 0 deletions misc/test_async_weaken_access.py
@@ -0,0 +1,17 @@
import taichi as ti

ti.init()

x = ti.var(ti.i32)
y = ti.var(ti.i32)

ti.root.pointer(ti.ij, 4).dense(ti.ij, 8).place(x, y)


@ti.kernel
def copy():
    for i, j in y:
        x[i, j] = y[i, j]


copy()
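
For contrast, a minimal counter-example (not part of the commit; transposed_copy is a hypothetical kernel reusing the x, y, and layout above) where the WeakenAccess pass added in flag_access.cpp below must leave the activation in place: the access indices are the loop indices in the wrong order, so the write can land on an element other than the one currently being visited.

@ti.kernel
def transposed_copy():
    for i, j in y:
        # The first access index is the loop variable j (loop index 1), not
        # i (loop index 0), so same_as_loop_snode is false in WeakenAccess
        # and the write to x keeps its activating access.
        x[j, i] = y[i, j]


transposed_copy()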
113 changes: 111 additions & 2 deletions taichi/program/async_engine.cpp
@@ -6,6 +6,7 @@
#include "taichi/program/program.h"
#include "taichi/backends/cpu/codegen_cpu.h"
#include "taichi/common/testing.h"
#include "taichi/util/statistics.h"

TLANG_NAMESPACE_BEGIN

@@ -60,6 +61,23 @@ void ExecutionQueue::enqueue(KernelLaunchRecord ker) {
      func = compiled_func[h];
      break;
    }
    stat.add("launched_kernels", 1.0);
    auto task_type = ker.stmt->task_type;
    if (task_type == OffloadedStmt::TaskType::listgen) {
      stat.add("launched_kernels_list_op", 1.0);
      stat.add("launched_kernels_list_gen", 1.0);
    } else if (task_type == OffloadedStmt::TaskType::clear_list) {
      stat.add("launched_kernels_list_op", 1.0);
      stat.add("launched_kernels_list_clear", 1.0);
    } else if (task_type == OffloadedStmt::TaskType::range_for) {
      stat.add("launched_kernels_compute", 1.0);
      stat.add("launched_kernels_range_for", 1.0);
    } else if (task_type == OffloadedStmt::TaskType::struct_for) {
      stat.add("launched_kernels_compute", 1.0);
      stat.add("launched_kernels_struct_for", 1.0);
    } else if (task_type == OffloadedStmt::TaskType::gc) {
      stat.add("launched_kernels_garbage_collect", 1.0);
    }
    auto context = ker.context;
    func(context);
  });
@@ -82,17 +100,108 @@ void AsyncEngine::launch(Kernel *kernel) {
  auto &offloads = block->statements;
  for (std::size_t i = 0; i < offloads.size(); i++) {
    auto offload = offloads[i]->as<OffloadedStmt>();
-   task_queue.emplace_back(kernel->program.get_context(), kernel, offload);
+   KernelLaunchRecord rec(kernel->program.get_context(), kernel, offload);
+   enqueue(rec);
  }
  optimize();
}

void AsyncEngine::enqueue(KernelLaunchRecord t) {
  using namespace irpass::analysis;

  task_queue.push_back(t);

  auto &meta = metas[t.h];
  // TODO: this is an abuse since it gathers nothing...
  gather_statements(t.stmt, [&](Stmt *stmt) {
    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
      for (auto &snode : global_ptr->snodes.data) {
        meta.input_snodes.insert(snode);
      }
    }
    if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
      if (auto ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
        for (auto &snode : ptr->snodes.data) {
          meta.input_snodes.insert(snode);
        }
      }
    }
    if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
      if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
        for (auto &snode : ptr->snodes.data) {
          meta.output_snodes.insert(snode);
        }
      }
    }
    if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
      if (auto ptr = global_atomic->dest->cast<GlobalPtrStmt>()) {
        for (auto &snode : ptr->snodes.data) {
          meta.input_snodes.insert(snode);
          meta.output_snodes.insert(snode);
        }
      }
    }

    if (auto ptr = stmt->cast<GlobalPtrStmt>()) {
      if (ptr->activate) {
        for (auto &snode : ptr->snodes.data) {
          meta.activation_snodes.insert(snode);
          // fmt::print(" **** act {}\n", snode->get_node_type_name_hinted());
        }
      }
    }
    return false;
  });
}

void AsyncEngine::synchronize() {
  optimize();
  while (!task_queue.empty()) {
    queue.enqueue(task_queue.front());
    task_queue.pop_front();
  }
  queue.synchronize();
}

bool AsyncEngine::optimize() {
  // TODO: improve...
  bool modified = false;
  std::unordered_map<SNode *, bool> list_dirty;
  auto new_task_queue = std::deque<KernelLaunchRecord>();
  for (int i = 0; i < task_queue.size(); i++) {
    // Try to eliminate unused listgens
    auto t = task_queue[i];
    auto meta = metas[t.h];
    auto offload = t.stmt;
    bool keep = true;
    if (offload->task_type == OffloadedStmt::TaskType::listgen) {
      // keep
    } else if (offload->task_type == OffloadedStmt::TaskType::clear_list) {
      TI_ASSERT(task_queue[i + 1].stmt->task_type ==
                OffloadedStmt::TaskType::listgen);
      auto snode = offload->snode;
      if (list_dirty.find(snode) != list_dirty.end() && !list_dirty[snode]) {
        keep = false;  // safe to remove
        modified = true;
        i++;  // skip the following listgen as well
        continue;
      }
      list_dirty[snode] = false;
    } else {
      for (auto snode : meta.activation_snodes) {
        while (snode && snode->type != SNodeType::root) {
          list_dirty[snode] = true;
          snode = snode->parent;
        }
      }
    }
    if (keep) {
      new_task_queue.push_back(t);
    } else {
      modified = true;
    }
  }
  task_queue = std::move(new_task_queue);
  return modified;
}

TLANG_NAMESPACE_END
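
To see the pattern AsyncEngine::optimize() removes, consider a sketch of two back-to-back launches of the copy() kernel from misc/test_async_weaken_access.py above (not part of the commit; it assumes a build with the async engine enabled — the switch for that is not shown in this diff). Each launch is preceded by a clear_list + listgen pair for the ancestors of x and y; since WeakenAccess turned copy()'s write into a non-activating access, no SNode becomes dirty between the launches and the second pair is dead:

copy()  # clear_list + listgen for the pointer SNode, then the struct-for
copy()  # list_dirty is still false for every SNode, so optimize() removes
        # this launch's clear_list + listgen pair
ti.sync()  # flush the task queue; synchronize() calls optimize() first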
12 changes: 10 additions & 2 deletions taichi/program/async_engine.h
@@ -148,22 +148,30 @@ class AsyncEngine {
 public:
  // TODO: state machine

  struct TaskMeta {
    std::unordered_set<SNode *> input_snodes, output_snodes;
    std::unordered_set<SNode *> activation_snodes;
  };

  std::unordered_map<std::uint64_t, TaskMeta> metas;

  ExecutionQueue queue;

  std::deque<KernelLaunchRecord> task_queue;

  AsyncEngine() {
  }

-  void optimize() {
-  }
+  bool optimize();  // return true when modified

  void clear_cache() {
    queue.clear_cache();
  }

  void launch(Kernel *kernel);

  void enqueue(KernelLaunchRecord t);

  void synchronize();
};

6 changes: 6 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
@@ -106,6 +106,12 @@ void compile_to_offloads(IRNode *ir,
print("Offloaded");
irpass::analysis::verify(ir);

if (!lower_global_access) {
irpass::flag_access(ir);
print("Access flagged after offloading");
irpass::analysis::verify(ir);
}

irpass::extract_constant(ir);
print("Constant extracted II");

71 changes: 70 additions & 1 deletion taichi/transforms/flag_access.cpp
@@ -60,10 +60,79 @@ class FlagAccess : public IRVisitor {
  }
};

// For struct-fors, weaken accesses on variables currently being looped over.
// E.g.,
//   for i in x:
//     x[i] = 0
// Although we are writing to x[i], i only loops over the active elements of
// x, so no extra activation is needed. Note that the indices of the x
// accesses must be the loop indices for this optimization to be correct.

class WeakenAccess : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  WeakenAccess(IRNode *node) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
    node->accept(this);
  }

  void visit(Block *stmt_list) {  // block itself has no id
    for (auto &stmt : stmt_list->statements) {
      stmt->accept(this);
    }
  }

  void visit(OffloadedStmt *stmt) {
    current_offload = stmt;
    if (stmt->body)
      stmt->body->accept(this);
    current_offload = nullptr;
  }

  void visit(GlobalPtrStmt *stmt) {
    if (stmt->activate) {
      if (current_offload &&
          current_offload->task_type == OffloadedStmt::TaskType::struct_for) {
        bool same_as_loop_snode = true;
        for (auto snode : stmt->snodes.data) {
          if (snode->type == SNodeType::place) {
            snode = snode->parent;
          }
          if (snode != current_offload->snode) {
            same_as_loop_snode = false;
          }
          if (stmt->indices.size() ==
              current_offload->snode->num_active_indices)
            for (int i = 0; i < current_offload->snode->num_active_indices;
                 i++) {
              auto ind = stmt->indices[i];
              // TODO: vectorized cases?
              if (auto loop_var = ind->cast<LoopIndexStmt>()) {
                if (loop_var->index != i) {
                  same_as_loop_snode = false;
                }
              } else {
                same_as_loop_snode = false;
              }
            }
        }
        if (same_as_loop_snode)
          stmt->activate = false;
      }
    }
  }

 private:
  OffloadedStmt *current_offload;
};

namespace irpass {

void flag_access(IRNode *root) {
-  FlagAccess instance(root);
+  FlagAccess flag_access(root);
+  WeakenAccess weaken_access(root);
}

}  // namespace irpass
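
The snode != current_offload->snode branch in WeakenAccess covers a case neither example in this commit exercises: the written variable lives in a different SNode tree from the one being looped over. A minimal sketch (not part of the commit; z and scatter are hypothetical, reusing x, y, and the layout from misc/test_async_weaken_access.py above):

z = ti.var(ti.i32)
ti.root.pointer(ti.ij, 4).dense(ti.ij, 8).place(z)  # a separate pointer tree

@ti.kernel
def scatter():
    for i, j in y:
        # z's parent SNode is not the loop SNode, so this write keeps its
        # activating access even though the indices are the loop indices.
        z[i, j] = y[i, j]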
