Skip to content

Commit

Permalink
[async] Add element-wise info into TaskMeta for fusion (#1884)
Browse files Browse the repository at this point in the history
* Refactor AsyncStateHash to std::hash

* Add TaskMeta::element_wise

* Make use of it in fuse()

* Fix false positive when task type is serial

* [skip ci] Apply suggestions from code review

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>

* [skip ci] enforce code format

* retrigger CI

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
3 people authored Sep 21, 2020
1 parent f2a798d commit 9a23ecd
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 43 deletions.
22 changes: 22 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,28 @@ GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
TI_STMT_REG_FIELDS;
}

bool GlobalPtrStmt::is_element_wise(SNode *snode) const {
if (snode == nullptr) {
// check every SNode when "snode" is nullptr
for (const auto &snode_i : snodes.data) {
if (!is_element_wise(snode_i)) {
return false;
}
}
return true;
}
// check if this statement is element-wise on a specific SNode, i.e., argument
// "snode"
for (int i = 0; i < (int)indices.size(); i++) {
if (auto loop_index_i = indices[i]->cast<LoopIndexStmt>();
!(loop_index_i && loop_index_i->loop->is<OffloadedStmt>() &&
loop_index_i->index == snode->physical_index_position[i])) {
return false;
}
}
return true;
}

std::string GlobalPtrExpression::serialize() {
std::string s = fmt::format("{}[", var.serialize());
for (int i = 0; i < (int)indices.size(); i++) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ class GlobalPtrStmt : public Stmt {
const std::vector<Stmt *> &indices,
bool activate = true);

bool is_element_wise(SNode *snode) const;

bool has_global_side_effect() const override {
return activate;
}
Expand Down
9 changes: 9 additions & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
}
}
}
for (auto &snode : ptr->snodes.data) {
if (ptr->is_element_wise(snode)) {
if (meta.element_wise.find(snode) == meta.element_wise.end()) {
meta.element_wise[snode] = true;
}
} else {
meta.element_wise[snode] = false;
}
}
}
if (auto clear_list = stmt->cast<ClearListStmt>()) {
meta.output_states.emplace(clear_list->snode, AsyncState::Type::list);
Expand Down
22 changes: 22 additions & 0 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ void TaskMeta::print() const {
}
fmt::print("\n");
}
std::vector<SNode *> element_wise_snodes, non_element_wise_snodes;
for (auto s : element_wise) {
if (s.second) {
element_wise_snodes.push_back(s.first);
} else {
non_element_wise_snodes.push_back(s.first);
}
}
if (!element_wise_snodes.empty()) {
fmt::print(" element-wise snodes:\n ");
for (auto s : element_wise_snodes) {
fmt::print("{} ", s->get_node_type_name_hinted());
}
fmt::print("\n");
}
if (!non_element_wise_snodes.empty()) {
fmt::print(" non-element-wise snodes:\n ");
for (auto s : non_element_wise_snodes) {
fmt::print("{} ", s->get_node_type_name_hinted());
}
fmt::print("\n");
}
}

TLANG_NAMESPACE_END
60 changes: 31 additions & 29 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,6 @@ class IRHandle {
uint64 hash_;
};

TLANG_NAMESPACE_END

namespace std {
template <>
struct hash<taichi::lang::IRHandle> {
std::size_t operator()(const taichi::lang::IRHandle &ir_handle) const
noexcept {
return ir_handle.hash();
}
};

template <>
struct hash<std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>> {
std::size_t operator()(
const std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>
&ir_handles) const noexcept {
return ir_handles.first.hash() * 100000007UL + ir_handles.second.hash();
}
};
} // namespace std

TLANG_NAMESPACE_BEGIN

// Records the necessary data for launching an offloaded task.
class TaskLaunchRecord {
public:
Expand Down Expand Up @@ -127,19 +104,44 @@ struct AsyncState {
}
};

class AsyncStateHash {
public:
size_t operator()(const AsyncState &s) const {
return (uint64)s.snode ^ (uint64)s.type;
TLANG_NAMESPACE_END

namespace std {
template <>
struct hash<taichi::lang::IRHandle> {
std::size_t operator()(const taichi::lang::IRHandle &ir_handle) const
noexcept {
return ir_handle.hash();
}
};

template <>
struct hash<std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>> {
std::size_t operator()(
const std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>
&ir_handles) const noexcept {
return ir_handles.first.hash() * 100000007UL + ir_handles.second.hash();
}
};

template <>
struct hash<taichi::lang::AsyncState> {
std::size_t operator()(const taichi::lang::AsyncState &s) const noexcept {
return (std::size_t)s.snode ^ (std::size_t)s.type;
}
};

} // namespace std

TLANG_NAMESPACE_BEGIN

struct TaskMeta {
std::string name;
OffloadedStmt::TaskType type{OffloadedStmt::TaskType::serial};
SNode *snode{nullptr}; // struct-for and listgen only
std::unordered_set<AsyncState, AsyncStateHash> input_states;
std::unordered_set<AsyncState, AsyncStateHash> output_states;
std::unordered_set<AsyncState> input_states;
std::unordered_set<AsyncState> output_states;
std::unordered_map<SNode *, bool> element_wise;

void print() const;
};
Expand Down
39 changes: 30 additions & 9 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ bool StateFlowGraph::fuse() {
std::vector<Bitset> has_path, has_path_reverse;
std::tie(has_path, has_path_reverse) = compute_transitive_closure();

// Cache the result that if each pair is fusable by task types.
// Cache the result that if each pair is fusible by task types.
// TODO: improve this
auto task_type_fusable = std::make_unique<Bitset[]>(n);
auto task_type_fusible = std::make_unique<Bitset[]>(n);
for (int i = 0; i < n; i++) {
task_type_fusable[i] = Bitset(n);
task_type_fusible[i] = Bitset(n);
}
// nodes_[0] is the initial node.
for (int i = 1; i < n; i++) {
Expand Down Expand Up @@ -292,10 +292,10 @@ bool StateFlowGraph::fuse() {
// TODO: avoid snode accessors going into async engine
const bool is_snode_accessor =
(rec_i.kernel->is_accessor || rec_j.kernel->is_accessor);
bool fusable =
bool fusible =
(is_same_range_for || is_same_struct_for || are_both_serial) &&
kernel_args_match && !is_snode_accessor;
task_type_fusable[i][j] = fusable;
task_type_fusible[i][j] = fusible;
}
}

Expand Down Expand Up @@ -345,6 +345,29 @@ bool StateFlowGraph::fuse() {

auto fused = std::make_unique<bool[]>(n);

auto edge_fusible = [&](int a, int b) {
// Check if a and b are fusible if there is an edge (a, b).
if (fused[a] || fused[b] || !task_type_fusible[a][b]) {
return false;
}
if (nodes_[a]->meta->type == OffloadedStmt::TaskType::serial) {
return true;
}
for (auto &state : nodes_[a]->output_edges) {
if (state.first.type != AsyncState::Type::value) {
// TODO: What checks do we need for edges of mask/list states?
continue;
}
if (state.second.find(nodes_[b].get()) != state.second.end()) {
if (!nodes_[a]->meta->element_wise[state.first.snode] ||
!nodes_[b]->meta->element_wise[state.first.snode]) {
return false;
}
}
}
return true;
};

bool modified = false;
while (true) {
bool updated = false;
Expand All @@ -357,9 +380,7 @@ bool StateFlowGraph::fuse() {
for (auto &edges : nodes_[i]->output_edges) {
for (auto &edge : edges.second) {
const int j = edge->node_id;
// TODO: for each pair of edge (i, j), we can only fuse if they
// are both serial or both element-wise.
if (!fused[j] && task_type_fusable[i][j]) {
if (edge_fusible(i, j)) {
auto i_has_path_to_j = has_path[i] & has_path_reverse[j];
i_has_path_to_j[i] = i_has_path_to_j[j] = false;
// check if i doesn't have a path to j of length >= 2
Expand All @@ -381,7 +402,7 @@ bool StateFlowGraph::fuse() {
for (int i = 1; i < n; i++) {
if (!fused[i]) {
for (int j = i + 1; j < n; j++) {
if (!fused[j] && task_type_fusable[i][j] && !has_path[i][j] &&
if (!fused[j] && task_type_fusible[i][j] && !has_path[i][j] &&
!has_path[j][i]) {
do_fuse(i, j);
fused[i] = fused[j] = true;
Expand Down
9 changes: 4 additions & 5 deletions taichi/program/state_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ class IRBank;
class StateFlowGraph {
public:
struct Node;
using StateToNodeMapping =
std::unordered_map<AsyncState, Node *, AsyncStateHash>;
using StateToNodeMapping = std::unordered_map<AsyncState, Node *>;

// Each node is a task
// Note: after SFG is done, each node here should hold a TaskLaunchRecord.
Expand All @@ -39,8 +38,8 @@ class StateFlowGraph {

// Profiling showed horrible performance using std::unordered_multimap (at
// least on Mac with clang-1103.0.32.62)...
std::unordered_map<AsyncState, std::unordered_set<Node *>, AsyncStateHash>
output_edges, input_edges;
std::unordered_map<AsyncState, std::unordered_set<Node *>> output_edges,
input_edges;

std::string string() const;

Expand Down Expand Up @@ -140,7 +139,7 @@ class StateFlowGraph {
Node *initial_node_; // The initial node holds all the initial states.
TaskMeta initial_meta_;
StateToNodeMapping latest_state_owner_;
std::unordered_map<AsyncState, std::unordered_set<Node *>, AsyncStateHash>
std::unordered_map<AsyncState, std::unordered_set<Node *>>
latest_state_readers_;
std::unordered_map<std::string, int> task_name_to_launch_ids_;
IRBank *ir_bank_;
Expand Down

0 comments on commit 9a23ecd

Please sign in to comment.