Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Support SFG-level DSE for scalar SNodes #1907

Merged
merged 13 commits into from
Oct 2, 2020
35 changes: 28 additions & 7 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,11 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
replace_with(i, std::move(local_load), true);
modified = true;
continue;
} else if (!is_parallel_executed) {
} else if (!is_parallel_executed ||
(atomic->dest->is<GlobalPtrStmt>() &&
atomic->dest->as<GlobalPtrStmt>()
->snodes[0]
->is_scalar())) {
// If this node is parallel executed, we can't weaken a global
// atomic operation to a global load.
// TODO: we can weaken it if it's element-wise (i.e. never
Expand Down Expand Up @@ -619,22 +623,37 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
}
}

void ControlFlowGraph::live_variable_analysis(bool after_lower_access) {
void ControlFlowGraph::live_variable_analysis(
bool after_lower_access,
const std::optional<LiveVarAnalysisConfig> &config_opt) {
TI_AUTO_PROF;
const int num_nodes = size();
std::queue<CFGNode *> to_visit;
std::unordered_map<CFGNode *, bool> in_queue;
TI_ASSERT(nodes[final_node]->empty());
nodes[final_node]->live_gen.clear();
nodes[final_node]->live_kill.clear();

// Decides whether a store destination |stmt| should be put into the final
// node's live_gen set, i.e. whether it has to be treated as possibly read
// after this kernel.
auto in_final_node_live_gen = [&config_opt](const Stmt *stmt) -> bool {
  // Allocas never go into the final node's live_gen set.
  const bool is_alloca =
      stmt->is<AllocaStmt>() || stmt->is<StackAllocaStmt>();
  if (is_alloca) {
    return false;
  }
  auto *global_ptr = stmt->cast<GlobalPtrStmt>();
  if (global_ptr != nullptr && config_opt.has_value()) {
    TI_ASSERT(global_ptr->snodes.size() == 1);
    // Stores to an SNode listed as eliminable are not kept alive.
    const auto &eliminable = config_opt->eliminable_snodes;
    return eliminable.count(global_ptr->snodes[0]) == 0;
  }
  // A global pointer that may be loaded after this kernel.
  return true;
};
if (!after_lower_access) {
for (int i = 0; i < num_nodes; i++) {
for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
auto stmt = nodes[i]->block->statements[j].get();
for (auto store_ptr : irpass::analysis::get_store_destination(stmt)) {
if (!store_ptr->is<AllocaStmt>() &&
!store_ptr->is<StackAllocaStmt>()) {
// A global pointer that may be loaded after this kernel.
if (in_final_node_live_gen(store_ptr)) {
nodes[final_node]->live_gen.insert(store_ptr);
}
}
Expand Down Expand Up @@ -753,9 +772,11 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access) {
return modified;
}

bool ControlFlowGraph::dead_store_elimination(bool after_lower_access) {
bool ControlFlowGraph::dead_store_elimination(
bool after_lower_access,
const std::optional<LiveVarAnalysisConfig> &lva_config_opt) {
TI_AUTO_PROF;
live_variable_analysis(after_lower_access);
live_variable_analysis(after_lower_access, lva_config_opt);
const int num_nodes = size();
bool modified = false;
for (int i = 0; i < num_nodes; i++) {
Expand Down
17 changes: 15 additions & 2 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

#include <optional>
#include <unordered_set>

#include "taichi/ir/ir.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -70,6 +73,12 @@ class ControlFlowGraph {
void erase(int node_id);

public:
// Optional configuration accepted by live_variable_analysis() and
// dead_store_elimination().
struct LiveVarAnalysisConfig {
// This is mostly useful for SFG task-level dead store elimination. SFG may
// detect certain cases where writes to one or more SNodes in a task are
// eliminable.
//
// The set of SNodes whose stores the caller declares eliminable.
// NOTE(review): the analysis is expected to not keep stores to these SNodes
// alive past the end of the task -- confirm against live_variable_analysis().
std::unordered_set<const SNode *> eliminable_snodes;
};
std::vector<std::unique_ptr<CFGNode>> nodes;
const int start_node = 0;
int final_node{0};
Expand All @@ -85,7 +94,9 @@ class ControlFlowGraph {

void print_graph_structure() const;
void reaching_definition_analysis(bool after_lower_access);
void live_variable_analysis(bool after_lower_access);
void live_variable_analysis(
bool after_lower_access,
const std::optional<LiveVarAnalysisConfig> &config_opt);

void simplify_graph();

Expand All @@ -96,7 +107,9 @@ class ControlFlowGraph {
bool store_to_load_forwarding(bool after_lower_access);

// Also performs identical load elimination.
bool dead_store_elimination(bool after_lower_access);
bool dead_store_elimination(
bool after_lower_access,
const std::optional<LiveVarAnalysisConfig> &lva_config_opt);
};

TLANG_NAMESPACE_END
4 changes: 4 additions & 0 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ bool SNode::is_place() const {
return type == SNodeType::place;
}

// Returns true iff this SNode is a `place` SNode with zero active indices.
bool SNode::is_scalar() const {
  if (!is_place()) {
    return false;
  }
  return num_active_indices == 0;
}

bool SNode::has_grad() const {
auto adjoint = expr.cast<GlobalVariableExpression>()->adjoint;
return is_primal() && adjoint.expr != nullptr &&
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ class SNode {

bool is_place() const;

bool is_scalar() const;

const Expr &get_expr() const {
return expr;
}
Expand Down
13 changes: 10 additions & 3 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#pragma once

#include "taichi/ir/ir.h"
#include <atomic>
#include <unordered_set>
#include <optional>
#include <unordered_map>
#include <unordered_set>

#include "taichi/ir/control_flow_graph.h"
#include "taichi/ir/ir.h"

TLANG_NAMESPACE_BEGIN

Expand All @@ -19,7 +22,11 @@ void re_id(IRNode *root);
void flag_access(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, Kernel *kernel = nullptr);
bool cfg_optimization(IRNode *root, bool after_lower_access);
bool cfg_optimization(
IRNode *root,
bool after_lower_access,
const std::optional<ControlFlowGraph::LiveVarAnalysisConfig>
&lva_config_opt = std::nullopt);
bool alg_simp(IRNode *root);
bool demote_operations(IRNode *root);
bool binary_op_simplify(IRNode *root);
Expand Down
5 changes: 4 additions & 1 deletion taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
if (!snode->is_scalar()) {
// TODO: This is ad-hoc, use value killing analysis
meta.input_states.emplace(snode, AsyncState::Type::value);
}
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
Expand Down
37 changes: 37 additions & 0 deletions taichi/program/ir_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,41 @@ IRHandle IRBank::demote_activation(IRHandle handle) {
return result;
}

// Runs irpass::cfg_optimization() on a clone of the IR behind |handle|, with
// |snodes| passed as the SNodes whose stores may be eliminated. Results are
// memoized in optimize_dse_bank_ under the (handle, snodes) key.
//
// Returns:
// * IRHandle: the resulting handle; this is |handle| itself when the pass did
//   not modify the IR.
// * bool: true only when the result was served from the cache.
std::pair<IRHandle, bool> IRBank::optimize_dse(
    IRHandle handle,
    const std::set<const SNode *> &snodes,
    bool verbose) {
  // operator[] creates the slot on a miss, so a non-empty handle here means
  // an earlier call already filled it in.
  IRHandle &cached = optimize_dse_bank_[OptimizeDseKey(handle, snodes)];
  const bool cache_hit = !cached.empty();
  if (cache_hit) {
    return std::make_pair(cached, true);
  }

  // All optimization happens on the clone returned by handle.clone().
  std::unique_ptr<IRNode> optimized_ir = handle.clone();
  IRNode *const ir_root = optimized_ir.get();

  if (verbose) {
    TI_INFO(" DSE: before CFG");
    irpass::print(ir_root);
  }

  ControlFlowGraph::LiveVarAnalysisConfig lva_config;
  lva_config.eliminable_snodes.insert(snodes.begin(), snodes.end());
  const bool modified = irpass::cfg_optimization(
      ir_root, /*after_lower_access=*/false, lva_config);

  if (verbose) {
    TI_INFO(" DSE: after CFG, modified={}", modified);
    irpass::print(ir_root);
  }

  if (modified) {
    // Register the optimized IR with the bank under its own hash.
    cached = IRHandle(ir_root, get_hash(ir_root));
    insert(std::move(optimized_ir), cached.hash());
  } else {
    // Nothing changed: reuse the input handle. The clone is destroyed when
    // this function returns.
    cached = handle;
  }
  return std::make_pair(cached, false);
}

TLANG_NAMESPACE_END
49 changes: 49 additions & 0 deletions taichi/program/ir_bank.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#include <set>
#include <unordered_map>
#include <vector>

#include "taichi/program/async_utils.h"

TLANG_NAMESPACE_BEGIN
Expand All @@ -16,6 +20,16 @@ class IRBank {

IRHandle demote_activation(IRHandle handle);

// Try running DSE optimization on the IR identified by |handle|. |snodes|
// denotes the set of SNodes whose stores are safe to eliminate.
//
// Returns:
// * IRHandle: the (possibly) DSE-optimized IRHandle
// * bool: whether the result is already cached.
std::pair<IRHandle, bool> optimize_dse(IRHandle handle,
const std::set<const SNode *> &snodes,
bool verbose);

std::unordered_map<IRHandle, TaskMeta> meta_bank_;
std::unordered_map<IRHandle, TaskFusionMeta> fusion_meta_bank_;

Expand All @@ -25,6 +39,41 @@ class IRBank {
std::vector<std::unique_ptr<IRNode>> trash_bin; // prevent IR from deleted
std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
std::unordered_map<IRHandle, IRHandle> demote_activation_bank_;

// For DSE optimization, the input key is (IRHandle, [SNode*]). This is
// because it is possible that the same IRHandle may have different sets of
// SNode stores that are eliminable.
struct OptimizeDseKey {
IRHandle task_ir;
// Intentionally use (ordered) set so that hash is deterministic.
std::set<const SNode *> eliminable_snodes;

OptimizeDseKey(const IRHandle task_ir,
const std::set<const SNode *> &snodes)
: task_ir(task_ir), eliminable_snodes(snodes) {
}

bool operator==(const OptimizeDseKey &other) const {
return (task_ir == other.task_ir) &&
(eliminable_snodes == other.eliminable_snodes);
}

bool operator!=(const OptimizeDseKey &other) const {
return !(*this == other);
}

struct Hash {
std::size_t operator()(const OptimizeDseKey &k) const {
std::size_t ret = k.task_ir.hash();
for (const auto *s : k.eliminable_snodes) {
ret = ret * 100000007UL + reinterpret_cast<uintptr_t>(s);
}
return ret;
}
};
};
std::unordered_map<OptimizeDseKey, IRHandle, OptimizeDseKey::Hash>
optimize_dse_bank_;
};

TLANG_NAMESPACE_END
Loading