Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Avoid unnecessary list generations and activations #913

Merged
merged 8 commits into from
May 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/mgpcg_advanced.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import time
import taichi as ti

real = ti.f32
Expand Down Expand Up @@ -210,5 +211,8 @@ def run(self):


# Driver: run the multigrid-preconditioned CG solver once, report wall-clock
# time, then dump taichi's internal profiling and statistics counters
# (the stats include the kernel-launch counters added in async_engine.cpp).
solver = MGPCG()
t = time.time()  # wall-clock start
solver.run()
print(f'Solver time: {time.time() - t:.3f} s')
ti.core.print_profile_info()
ti.core.print_stat()
17 changes: 17 additions & 0 deletions misc/test_async_weaken_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import taichi as ti

ti.init()

# Two fields sharing one sparse (pointer -> dense) layout, so a struct-for
# over y visits exactly the active elements of the block containing x.
x = ti.var(ti.i32)
y = ti.var(ti.i32)

ti.root.pointer(ti.ij, 4).dense(ti.ij, 8).place(x, y)


@ti.kernel
def copy():
    # Writes x[i, j] only at indices where y is already active; presumably
    # this exercises the access-weakening pass (WeakenAccess) so the write
    # does not need an extra activation -- confirm against flag_access.cpp.
    for i, j in y:
        x[i, j] = y[i, j]


copy()
113 changes: 111 additions & 2 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "taichi/program/program.h"
#include "taichi/backends/cpu/codegen_cpu.h"
#include "taichi/common/testing.h"
#include "taichi/util/statistics.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -60,6 +61,23 @@ void ExecutionQueue::enqueue(KernelLaunchRecord ker) {
func = compiled_func[h];
break;
}
stat.add("launched_kernels", 1.0);
auto task_type = ker.stmt->task_type;
if (task_type == OffloadedStmt::TaskType::listgen) {
stat.add("launched_kernels_list_op", 1.0);
stat.add("launched_kernels_list_gen", 1.0);
} else if (task_type == OffloadedStmt::TaskType::clear_list) {
stat.add("launched_kernels_list_op", 1.0);
stat.add("launched_kernels_list_clear", 1.0);
} else if (task_type == OffloadedStmt::TaskType::range_for) {
stat.add("launched_kernels_compute", 1.0);
stat.add("launched_kernels_range_for", 1.0);
} else if (task_type == OffloadedStmt::TaskType::struct_for) {
stat.add("launched_kernels_compute", 1.0);
stat.add("launched_kernels_struct_for", 1.0);
} else if (task_type == OffloadedStmt::TaskType::gc) {
stat.add("launched_kernels_garbage_collect", 1.0);
}
auto context = ker.context;
func(context);
});
Expand All @@ -82,17 +100,108 @@ void AsyncEngine::launch(Kernel *kernel) {
auto &offloads = block->statements;
for (std::size_t i = 0; i < offloads.size(); i++) {
auto offload = offloads[i]->as<OffloadedStmt>();
task_queue.emplace_back(kernel->program.get_context(), kernel, offload);
KernelLaunchRecord rec(kernel->program.get_context(), kernel, offload);
enqueue(rec);
}
optimize();
}

// Records a task for later execution and gathers its access metadata:
// which SNodes it reads (input), writes (output), and may activate.
// The metadata is consumed by optimize() to prune redundant list ops.
void AsyncEngine::enqueue(KernelLaunchRecord t) {
  using namespace irpass::analysis;

  task_queue.push_back(t);

  auto &meta = metas[t.h];
  // TODO: this is an abuse since it gathers nothing -- gather_statements is
  // used purely as an IR traversal; the predicate always returns false.
  gather_statements(t.stmt, [&](Stmt *stmt) {
    // A statement has exactly one concrete type, so dispatch with a single
    // else-if chain (the original cast<GlobalPtrStmt>() twice).
    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
      // Conservatively treat any pointer to an SNode as an input access.
      for (auto &snode : global_ptr->snodes.data) {
        meta.input_snodes.insert(snode);
      }
      // An activating pointer may turn on new elements of these SNodes.
      if (global_ptr->activate) {
        for (auto &snode : global_ptr->snodes.data) {
          meta.activation_snodes.insert(snode);
        }
      }
    } else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
      if (auto ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
        for (auto &snode : ptr->snodes.data) {
          meta.input_snodes.insert(snode);
        }
      }
    } else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
      if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
        for (auto &snode : ptr->snodes.data) {
          meta.output_snodes.insert(snode);
        }
      }
    } else if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
      if (auto ptr = global_atomic->dest->cast<GlobalPtrStmt>()) {
        // Atomics both read and write their destination.
        for (auto &snode : ptr->snodes.data) {
          meta.input_snodes.insert(snode);
          meta.output_snodes.insert(snode);
        }
      }
    }
    return false;  // gather nothing; we only want the traversal
  });
}

// Flushes every pending task to the execution queue, then blocks until
// all launched tasks have finished.
void AsyncEngine::synchronize() {
  // Give the optimizer a final chance to prune tasks before flushing.
  optimize();
  for (auto &rec : task_queue) {
    queue.enqueue(rec);
  }
  task_queue.clear();
  queue.synchronize();
}

bool AsyncEngine::optimize() {
// TODO: improve...
bool modified = false;
std::unordered_map<SNode *, bool> list_dirty;
auto new_task_queue = std::deque<KernelLaunchRecord>();
for (int i = 0; i < task_queue.size(); i++) {
// Try to eliminate unused listgens
auto t = task_queue[i];
auto meta = metas[t.h];
auto offload = t.stmt;
bool keep = true;
if (offload->task_type == OffloadedStmt::TaskType::listgen) {
// keep
} else if (offload->task_type == OffloadedStmt::TaskType::clear_list) {
TI_ASSERT(task_queue[i + 1].stmt->task_type ==
OffloadedStmt::TaskType::listgen);
auto snode = offload->snode;
if (list_dirty.find(snode) != list_dirty.end() && !list_dirty[snode]) {
keep = false; // safe to remove
modified = true;
i++; // skip the following list gen as well
continue;
}
list_dirty[snode] = false;
} else {
for (auto snode : meta.activation_snodes) {
while (snode && snode->type != SNodeType::root) {
list_dirty[snode] = true;
snode = snode->parent;
}
}
}
if (keep) {
new_task_queue.push_back(t);
} else {
modified = true;
}
}
task_queue = std::move(new_task_queue);
return modified;
}

TLANG_NAMESPACE_END
12 changes: 10 additions & 2 deletions taichi/program/async_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,30 @@ class AsyncEngine {
public:
// TODO: state machine

// Per-task access summary, filled in by enqueue() and consumed by
// optimize() to decide which list operations are redundant.
struct TaskMeta {
  // SNodes the task reads from / writes to.
  std::unordered_set<SNode *> input_snodes, output_snodes;
  // SNodes the task may activate (sparse activation).
  std::unordered_set<SNode *> activation_snodes;
};

// Metadata keyed by KernelLaunchRecord::h. NOTE(review): assumes h
// uniquely identifies the task's IR -- confirm in KernelLaunchRecord.
std::unordered_map<std::uint64_t, TaskMeta> metas;

ExecutionQueue queue;

// Tasks accepted via enqueue() but not yet flushed to `queue`.
std::deque<KernelLaunchRecord> task_queue;

AsyncEngine() {
}

bool optimize();  // return true when modified

void clear_cache() {
  queue.clear_cache();
}

void launch(Kernel *kernel);

void enqueue(KernelLaunchRecord t);

// Flushes task_queue into `queue` and blocks until all tasks complete.
void synchronize();
};

Expand Down
6 changes: 6 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ void compile_to_offloads(IRNode *ir,
print("Offloaded");
irpass::analysis::verify(ir);

if (!lower_global_access) {
irpass::flag_access(ir);
print("Access flagged after offloading");
irpass::analysis::verify(ir);
}

irpass::extract_constant(ir);
print("Constant extracted II");

Expand Down
71 changes: 70 additions & 1 deletion taichi/transforms/flag_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,79 @@ class FlagAccess : public IRVisitor {
}
};

// For struct fors, weaken accesses on variables currently being looped over.
// E.g.
// for i in x:
//   x[i] = 0
// Although we are writing to x[i], i only loops over the active elements
// of x, so no extra activation is needed. Note that the indices of the x
// accesses must be the loop indices for this optimization to be correct.

// Clears the `activate` flag on global pointers inside a struct-for when
// the pointer addresses exactly the SNode being looped over, using the
// loop's own indices -- such elements are already active.
class WeakenAccess : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  // Runs the pass immediately on construction.
  WeakenAccess(IRNode *node) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
    node->accept(this);
  }

  void visit(Block *stmt_list) {  // block itself has no id
    for (auto &stmt : stmt_list->statements) {
      stmt->accept(this);
    }
  }

  void visit(OffloadedStmt *stmt) {
    // Track the enclosing offload so pointer visits can compare against
    // the struct-for's SNode.
    current_offload = stmt;
    if (stmt->body)
      stmt->body->accept(this);
    current_offload = nullptr;
  }

  void visit(GlobalPtrStmt *stmt) {
    if (stmt->activate) {
      if (current_offload &&
          current_offload->task_type == OffloadedStmt::TaskType::struct_for) {
        bool same_as_loop_snode = true;
        for (auto snode : stmt->snodes.data) {
          // Compare against the container, not the leaf place node.
          if (snode->type == SNodeType::place) {
            snode = snode->parent;
          }
          if (snode != current_offload->snode) {
            same_as_loop_snode = false;
          }
          if (stmt->indices.size() ==
              current_offload->snode->num_active_indices) {
            for (int i = 0; i < current_offload->snode->num_active_indices;
                 i++) {
              auto ind = stmt->indices[i];
              // TODO: vectorized cases?
              if (auto loop_var = ind->cast<LoopIndexStmt>()) {
                // Index i must be exactly the i-th loop index.
                if (loop_var->index != i) {
                  same_as_loop_snode = false;
                }
              } else {
                same_as_loop_snode = false;
              }
            }
          } else {
            // Index-count mismatch: these cannot all be loop indices, so
            // keep the activation (the original silently skipped this
            // check, potentially weakening an unsafe access).
            same_as_loop_snode = false;
          }
        }
        if (same_as_loop_snode)
          stmt->activate = false;
      }
    }
  }

 private:
  // Offload currently being traversed; nullptr outside any offload body.
  // Must be initialized: a GlobalPtrStmt visited outside an offload would
  // otherwise read an indeterminate pointer (UB).
  OffloadedStmt *current_offload{nullptr};
};

namespace irpass {

// Runs access flagging, then weakens struct-for accesses that only touch
// the loop's own SNode elements (no extra activation needed).
// NOTE: the scraped diff retained both the old `FlagAccess instance(root);`
// and its replacement, which would run the pass twice; keep only one.
void flag_access(IRNode *root) {
  FlagAccess flag_access(root);
  WeakenAccess weaken_access(root);
}

} // namespace irpass
Expand Down