Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Add element-wise info into TaskMeta for fusion #1884

Merged
merged 7 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,27 @@ GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
TI_STMT_REG_FIELDS;
}

bool GlobalPtrStmt::is_element_wise(SNode *snode) const {
if (snode == nullptr) {
// check all snodes
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
for (const auto &snode_i : snodes.data) {
if (!is_element_wise(snode_i)) {
return false;
}
}
return true;
}
// check if this statement is element-wise on snode
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
for (int i = 0; i < (int)indices.size(); i++) {
if (auto loop_index_i = indices[i]->cast<LoopIndexStmt>();
!(loop_index_i && loop_index_i->loop->is<OffloadedStmt>() &&
loop_index_i->index == snode->physical_index_position[i])) {
return false;
}
}
return true;
}

std::string GlobalPtrExpression::serialize() {
std::string s = fmt::format("{}[", var.serialize());
for (int i = 0; i < (int)indices.size(); i++) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ class GlobalPtrStmt : public Stmt {
const std::vector<Stmt *> &indices,
bool activate = true);

bool is_element_wise(SNode *snode) const;

bool has_global_side_effect() const override {
return activate;
}
Expand Down
9 changes: 9 additions & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
}
}
}
for (auto &snode : ptr->snodes.data) {
if (ptr->is_element_wise(snode)) {
if (meta.element_wise.find(snode) == meta.element_wise.end()) {
meta.element_wise[snode] = true;
}
} else {
meta.element_wise[snode] = false;
}
}
}
if (auto clear_list = stmt->cast<ClearListStmt>()) {
meta.output_states.emplace(clear_list->snode, AsyncState::Type::list);
Expand Down
22 changes: 22 additions & 0 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ void TaskMeta::print() const {
}
fmt::print("\n");
}
std::vector<SNode *> element_wise_snodes, non_element_wise_snodes;
for (auto s : element_wise) {
if (s.second) {
element_wise_snodes.push_back(s.first);
} else {
non_element_wise_snodes.push_back(s.first);
}
}
if (!element_wise_snodes.empty()) {
fmt::print(" element-wise snodes:\n ");
for (auto s : element_wise_snodes) {
fmt::print("{} ", s->get_node_type_name_hinted());
}
fmt::print("\n");
}
if (!non_element_wise_snodes.empty()) {
fmt::print(" non-element-wise snodes:\n ");
for (auto s : non_element_wise_snodes) {
fmt::print("{} ", s->get_node_type_name_hinted());
}
fmt::print("\n");
}
}

TLANG_NAMESPACE_END
60 changes: 31 additions & 29 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,6 @@ class IRHandle {
uint64 hash_;
};

TLANG_NAMESPACE_END

namespace std {
template <>
struct hash<taichi::lang::IRHandle> {
std::size_t operator()(const taichi::lang::IRHandle &ir_handle) const
noexcept {
return ir_handle.hash();
}
};

template <>
struct hash<std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>> {
std::size_t operator()(
const std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>
&ir_handles) const noexcept {
return ir_handles.first.hash() * 100000007UL + ir_handles.second.hash();
}
};
} // namespace std

TLANG_NAMESPACE_BEGIN

// Records the necessary data for launching an offloaded task.
class TaskLaunchRecord {
public:
Expand Down Expand Up @@ -127,19 +104,44 @@ struct AsyncState {
}
};

class AsyncStateHash {
public:
size_t operator()(const AsyncState &s) const {
return (uint64)s.snode ^ (uint64)s.type;
TLANG_NAMESPACE_END

namespace std {
template <>
struct hash<taichi::lang::IRHandle> {
std::size_t operator()(const taichi::lang::IRHandle &ir_handle) const
noexcept {
return ir_handle.hash();
}
};

template <>
struct hash<std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>> {
std::size_t operator()(
const std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>
&ir_handles) const noexcept {
return ir_handles.first.hash() * 100000007UL + ir_handles.second.hash();
}
};

template <>
struct hash<taichi::lang::AsyncState> {
std::size_t operator()(const taichi::lang::AsyncState &s) const noexcept {
return (std::size_t)s.snode ^ (std::size_t)s.type;
}
};

} // namespace std

TLANG_NAMESPACE_BEGIN

struct TaskMeta {
std::string name;
OffloadedStmt::TaskType type{OffloadedStmt::TaskType::serial};
SNode *snode{nullptr}; // struct-for and listgen only
std::unordered_set<AsyncState, AsyncStateHash> input_states;
std::unordered_set<AsyncState, AsyncStateHash> output_states;
std::unordered_set<AsyncState> input_states;
std::unordered_set<AsyncState> output_states;
std::unordered_map<SNode *, bool> element_wise;

void print() const;
};
Expand Down
39 changes: 30 additions & 9 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ bool StateFlowGraph::fuse() {
std::vector<Bitset> has_path, has_path_reverse;
std::tie(has_path, has_path_reverse) = compute_transitive_closure();

// Cache the result that if each pair is fusable by task types.
// Cache the result that if each pair is fusible by task types.
// TODO: improve this
auto task_type_fusable = std::make_unique<Bitset[]>(n);
auto task_type_fusible = std::make_unique<Bitset[]>(n);
for (int i = 0; i < n; i++) {
task_type_fusable[i] = Bitset(n);
task_type_fusible[i] = Bitset(n);
}
// nodes_[0] is the initial node.
for (int i = 1; i < n; i++) {
Expand Down Expand Up @@ -292,10 +292,10 @@ bool StateFlowGraph::fuse() {
// TODO: avoid snode accessors going into async engine
const bool is_snode_accessor =
(rec_i.kernel->is_accessor || rec_j.kernel->is_accessor);
bool fusable =
bool fusible =
(is_same_range_for || is_same_struct_for || are_both_serial) &&
kernel_args_match && !is_snode_accessor;
task_type_fusable[i][j] = fusable;
task_type_fusible[i][j] = fusible;
}
}

Expand Down Expand Up @@ -345,6 +345,29 @@ bool StateFlowGraph::fuse() {

auto fused = std::make_unique<bool[]>(n);

auto edge_fusible = [&](int a, int b) {
// Check if a and b are fusible if there is an edge (a, b).
if (fused[a] || fused[b] || !task_type_fusible[a][b]) {
return false;
}
if (nodes_[a]->meta->type == OffloadedStmt::TaskType::serial) {
return true;
}
for (auto &state : nodes_[a]->output_edges) {
if (state.first.type != AsyncState::Type::value) {
// TODO: What checks do we need for edges of mask/list states?
continue;
}
if (state.second.find(nodes_[b].get()) != state.second.end()) {
if (!nodes_[a]->meta->element_wise[state.first.snode] ||
!nodes_[b]->meta->element_wise[state.first.snode]) {
return false;
}
}
}
return true;
};

bool modified = false;
while (true) {
bool updated = false;
Expand All @@ -357,9 +380,7 @@ bool StateFlowGraph::fuse() {
for (auto &edges : nodes_[i]->output_edges) {
for (auto &edge : edges.second) {
const int j = edge->node_id;
// TODO: for each pair of edge (i, j), we can only fuse if they
// are both serial or both element-wise.
if (!fused[j] && task_type_fusable[i][j]) {
if (edge_fusible(i, j)) {
auto i_has_path_to_j = has_path[i] & has_path_reverse[j];
i_has_path_to_j[i] = i_has_path_to_j[j] = false;
// check if i doesn't have a path to j of length >= 2
Expand All @@ -381,7 +402,7 @@ bool StateFlowGraph::fuse() {
for (int i = 1; i < n; i++) {
if (!fused[i]) {
for (int j = i + 1; j < n; j++) {
if (!fused[j] && task_type_fusable[i][j] && !has_path[i][j] &&
if (!fused[j] && task_type_fusible[i][j] && !has_path[i][j] &&
!has_path[j][i]) {
do_fuse(i, j);
fused[i] = fused[j] = true;
Expand Down
9 changes: 4 additions & 5 deletions taichi/program/state_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ class IRBank;
class StateFlowGraph {
public:
struct Node;
using StateToNodeMapping =
std::unordered_map<AsyncState, Node *, AsyncStateHash>;
using StateToNodeMapping = std::unordered_map<AsyncState, Node *>;

// Each node is a task
// Note: after SFG is done, each node here should hold a TaskLaunchRecord.
Expand All @@ -39,8 +38,8 @@ class StateFlowGraph {

// Profiling showed horrible performance using std::unordered_multimap (at
// least on Mac with clang-1103.0.32.62)...
std::unordered_map<AsyncState, std::unordered_set<Node *>, AsyncStateHash>
output_edges, input_edges;
std::unordered_map<AsyncState, std::unordered_set<Node *>> output_edges,
input_edges;

std::string string() const;

Expand Down Expand Up @@ -140,7 +139,7 @@ class StateFlowGraph {
Node *initial_node_; // The initial node holds all the initial states.
TaskMeta initial_meta_;
StateToNodeMapping latest_state_owner_;
std::unordered_map<AsyncState, std::unordered_set<Node *>, AsyncStateHash>
std::unordered_map<AsyncState, std::unordered_set<Node *>>
latest_state_readers_;
std::unordered_map<std::string, int> task_name_to_launch_ids_;
IRBank *ir_bank_;
Expand Down