Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] [IR] More accurate same_value analysis #2118

Merged
merged 11 commits into from
Dec 27, 2020
115 changes: 98 additions & 17 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/async_utils.h"
#include "taichi/program/ir_bank.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
Expand All @@ -18,23 +20,44 @@ class IRNodeComparator : public IRVisitor {
bool recursively_check_;
bool check_same_value_;
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved

std::unordered_set<AsyncState> possibly_modified_states_;
bool all_states_can_be_modified_;
IRBank *ir_bank_;

public:
bool same;

explicit IRNodeComparator(IRNode *other_node,
std::optional<std::unordered_map<int, int>> id_map,
bool check_same_value)
explicit IRNodeComparator(
IRNode *other_node,
const std::optional<std::unordered_map<int, int>> &id_map,
bool check_same_value,
const std::optional<std::unordered_set<AsyncState>>
&possibly_modified_states,
IRBank *ir_bank)
: other_node(other_node) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
same = true;
if (id_map.has_value()) {
recursively_check_ = true;
this->id_map = std::move(id_map.value());
this->id_map = id_map.value();
} else {
recursively_check_ = false;
}
if (possibly_modified_states.has_value()) {
TI_ASSERT_INFO(check_same_value,
"The parameter possibly_modified_states "
"is only supported when check_same_value is true");
TI_ASSERT_INFO(ir_bank,
"The parameter possibly_modified_states "
"requires ir_bank")
all_states_can_be_modified_ = false;
this->possibly_modified_states_ = possibly_modified_states.value();
} else {
all_states_can_be_modified_ = true;
}
check_same_value_ = check_same_value;
ir_bank_ = ir_bank;
}

void map_id(int this_id, int other_id) {
Expand Down Expand Up @@ -97,22 +120,57 @@ class IRNodeComparator : public IRVisitor {
same = false;
return;
}
auto other = other_node->as<Stmt>();

// If two identical statements can have different values, return false.
if (check_same_value_ && !stmt->is_container_statement() &&
!stmt->common_statement_eliminable()) {
same = false;
return;
// TODO: two identical GlobalPtrStmts cannot have different values,
// but GlobalPtrStmt::common_statement_eliminable() is false.
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
if (check_same_value_ && stmt != other && !stmt->is_container_statement() &&
k-ye marked this conversation as resolved.
Show resolved Hide resolved
!stmt->common_statement_eliminable() && !stmt->is<GlobalPtrStmt>()) {
if (all_states_can_be_modified_) {
same = false;
return;
} else {
bool same_value = false;
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto global_ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
TI_ASSERT(global_ptr->width() == 1);
if (possibly_modified_states_.count(ir_bank_->get_async_state(
global_ptr->snodes[0], AsyncState::Type::value)) == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, previous we were too conservative, and returned "not the same" even for snodes that are only read in the tasks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.

same_value = true;
}
}
// TODO: other cases?
}
if (!same_value) {
same = false;
return;
}
}
}
// Note that we do not need to test !stmt2->common_statement_eliminable()
// because if this condition does not hold,
// same_statements(stmt1, stmt2) returns false anyway.

// field check
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
auto other = other_node->as<Stmt>();
if (!stmt->field_manager.equal(other->field_manager)) {
same = false;
return;
if (check_same_value_ && stmt->is<GlobalPtrStmt>()) {
// Special case: we do not care the "activate" field when checking
// whether two global pointers share the same value.
// And we cannot use irpass::analysis::definitely_same_address()
// directly because that function does not support id_map.

// TODO: Update this part if GlobalPtrStmt comes to have more fields
TI_ASSERT(stmt->width() == 1);
if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
other->as<GlobalPtrStmt>()->snodes[0]->id) {
same = false;
return;
}
} else {
if (!stmt->field_manager.equal(other->field_manager)) {
same = false;
return;
}
}

// operand check
Expand Down Expand Up @@ -219,8 +277,12 @@ class IRNodeComparator : public IRVisitor {
static bool run(IRNode *root1,
IRNode *root2,
const std::optional<std::unordered_map<int, int>> &id_map,
bool check_same_value) {
IRNodeComparator comparator(root2, id_map, check_same_value);
bool check_same_value,
const std::optional<std::unordered_set<AsyncState>>
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
&possibly_modified_states,
IRBank *ir_bank) {
IRNodeComparator comparator(root2, id_map, check_same_value,
possibly_modified_states, ir_bank);
root1->accept(&comparator);
return comparator.same;
}
Expand Down Expand Up @@ -270,17 +332,36 @@ bool same_statements(
if (!root1 || !root2)
return false;
return IRNodeComparator::run(root1, root2, id_map,
/*check_same_value=*/false);
/*check_same_value=*/false, std::nullopt,
/*ir_bank=*/nullptr);
}
bool same_value(Stmt *stmt1,
Stmt *stmt2,
const AsyncStateSet &possibly_modified_states,
IRBank *ir_bank,
const std::optional<std::unordered_map<int, int>> &id_map) {
// Test if two statements definitely have the same value.
if (stmt1 == stmt2)
return true;
if (!stmt1 || !stmt2)
return false;
return IRNodeComparator::run(
stmt1, stmt2, id_map, /*check_same_value=*/true,
std::make_optional<std::unordered_set<AsyncState>>(
possibly_modified_states.s),
ir_bank);
}
bool same_value(Stmt *stmt1,
Stmt *stmt2,
const std::optional<std::unordered_map<int, int>> &id_map) {
// Test if two statements must have the same value.
// Test if two statements definitely have the same value.
if (stmt1 == stmt2)
return true;
if (!stmt1 || !stmt2)
return false;
return IRNodeComparator::run(stmt1, stmt2, id_map, /*check_same_value=*/true);
return IRNodeComparator::run(stmt1, stmt2, id_map,
/*check_same_value=*/true, std::nullopt,
/*ir_bank=*/nullptr);
}
} // namespace irpass::analysis

Expand Down
41 changes: 33 additions & 8 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ControlFlowGraph;

struct TaskMeta;
class IRBank;
class AsyncStateSet;

// IR Analysis
namespace irpass::analysis {
Expand Down Expand Up @@ -82,20 +83,44 @@ std::vector<Stmt *> get_store_destination(Stmt *store_stmt);
bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
std::pair<bool, Stmt *> last_store_or_atomic(IRNode *root, Stmt *var);
bool maybe_same_address(Stmt *var1, Stmt *var2);
/** Test if root1 and root2 are the same, i.e., have the same type,
* the same operands, the same fields, and the same containing statements.
/**
* Test if root1 and root2 are the same, i.e., have the same type,
* the same operands, the same fields, and the same containing statements.
*
* @param id_map
* If id_map is std::nullopt by default, two operands are considered
* the same if they have the same id and do not belong to either root,
* or they belong to root1 and root2 at the same position in the roots.
* Otherwise, this function also recursively check the operands until
* ids in the id_map are reached.
* @param id_map
* If id_map is std::nullopt by default, two operands are considered
* the same if they have the same id and do not belong to either root,
* or they belong to root1 and root2 at the same position in the roots.
* Otherwise, this function also recursively check the operands until
* ids in the id_map are reached.
*/
bool same_statements(
IRNode *root1,
IRNode *root2,
const std::optional<std::unordered_map<int, int>> &id_map = std::nullopt);
/**
* Test if stmt1 and stmt2 definitely have the same value.
*
* @param possibly_modified_states
* Only states in possibly_modified_states can be modified
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
* between stmt1 and stmt2.
*
* @param id_map
* Same as in same_statements(root1, root2, id_map).
*/
bool same_value(
Stmt *stmt1,
Stmt *stmt2,
const AsyncStateSet &possibly_modified_states,
IRBank *ir_bank,
const std::optional<std::unordered_map<int, int>> &id_map = std::nullopt);
/**
* Test if stmt1 and stmt2 definitely have the same value.
* Any global fields can be modified between stmt1 and stmt2.
*
* @param id_map
* Same as in same_statements(root1, root2, id_map).
*/
bool same_value(
Stmt *stmt1,
Stmt *stmt2,
Expand Down
6 changes: 6 additions & 0 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ struct TaskMeta {
void print() const;
};

// A wrapper class for the parameter in bool same_value() in analysis.h.
class AsyncStateSet {
public:
std::unordered_set<AsyncState> s;
};

class IRBank;

TaskMeta *get_task_meta(IRBank *bank, const TaskLaunchRecord &t);
Expand Down
10 changes: 9 additions & 1 deletion taichi/program/ir_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,15 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
return result;
}

TI_TRACE("Begin uncached fusion");
TI_TRACE("Begin uncached fusion: [{}(size={})] <- [{}(size={})]",
handle_a.ir()->get_kernel()->name,
(handle_a.ir()->as<OffloadedStmt>()->has_body()
? handle_a.ir()->as<OffloadedStmt>()->body->size()
: -1),
handle_b.ir()->get_kernel()->name,
(handle_b.ir()->as<OffloadedStmt>()->has_body()
? handle_a.ir()->as<OffloadedStmt>()->body->size()
: -1));
// We are about to change both |task_a| and |task_b|. Clone them first.
auto cloned_task_a = handle_a.clone();
auto cloned_task_b = handle_b.clone();
Expand Down
29 changes: 20 additions & 9 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,12 +654,22 @@ std::unordered_set<int> StateFlowGraph::fuse_range(int begin, int end) {

auto edge_fusible = [&](int a, int b) {
TI_PROFILER("edge_fusible");
TI_ASSERT(a >= 0 && a < nodes.size() && b >= 0 && b < nodes.size());
// Check if a and b are fusible if there is an edge (a, b).
if (fused[a] || fused[b] || !fusion_meta[a].fusible ||
fusion_meta[a] != fusion_meta[b]) {
return false;
}
if (nodes[a]->meta->type != OffloadedTaskType::serial) {
std::unordered_map<int, int> offload_map;
offload_map[0] = 0;
std::unordered_set<AsyncState> modified_states =
get_task_meta(ir_bank_, nodes[a]->rec)->output_states;
std::unordered_set<AsyncState> modified_states_b =
get_task_meta(ir_bank_, nodes[a]->rec)->output_states;
modified_states.insert(modified_states_b.begin(),
modified_states_b.end());
AsyncStateSet modified_states_set{modified_states};
for (auto state_iter = nodes[a]->output_edges.get_state_iterator();
!state_iter.done(); ++state_iter) {
auto state = state_iter.get_state();
Expand Down Expand Up @@ -689,11 +699,10 @@ std::unordered_set<int> StateFlowGraph::fuse_range(int begin, int end) {
TI_ASSERT(nodes[b]->rec.stmt()->id == 0);
// Only map the OffloadedStmt to see if both SNodes are loop-unique
// on the same statement.
std::unordered_map<int, int> offload_map;
offload_map[0] = 0;
for (int i = 0; i < (int)ptr1->indices.size(); i++) {
if (!irpass::analysis::same_value(
ptr1->indices[i], ptr2->indices[i],
ptr1->indices[i], ptr2->indices[i], modified_states_set,
ir_bank_,
std::make_optional<std::unordered_map<int, int>>(
offload_map))) {
return false;
Expand Down Expand Up @@ -781,12 +790,14 @@ std::unordered_set<int> StateFlowGraph::fuse_range(int begin, int end) {
// Fuse no more than one task into task i
bool i_updated = false;
for (auto &edge : nodes[i]->output_edges.get_all_edges()) {
const int j = edge.second->pending_node_id - begin;
if (j != -1 && edge_fusible(i, j)) {
do_fuse(i, j);
// Iterators of nodes[i]->output_edges may be invalidated
i_updated = true;
break;
if (edge.second->pending()) {
const int j = edge.second->pending_node_id - begin;
if (j >= 0 && j < nodes.size() && edge_fusible(i, j)) {
do_fuse(i, j);
// Iterators of nodes[i]->output_edges may be invalidated
i_updated = true;
break;
}
}

if (i_updated) {
Expand Down
7 changes: 2 additions & 5 deletions tests/python/test_bit_array_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,21 @@ def test_vectorized_struct_for():
@ti.kernel
def init():
for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
(boundary_offset, N - boundary_offset)):
(boundary_offset, N - boundary_offset)):
x[i, j] = ti.random(dtype=ti.i32) % 2


@ti.kernel
def assign_vectorized():
ti.bit_vectorize(32)
for i, j in x:
y[i, j] = x[i, j]


@ti.kernel
def verify():
for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
(boundary_offset, N - boundary_offset)):
(boundary_offset, N - boundary_offset)):
assert y[i, j] == x[i, j]


init()
assign_vectorized()
verify()