Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] [IR] More accurate same_value analysis #2118

Merged
merged 11 commits into from
Dec 27, 2020
142 changes: 121 additions & 21 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/async_utils.h"
#include "taichi/program/ir_bank.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>

TLANG_NAMESPACE_BEGIN

Expand All @@ -16,25 +17,52 @@ class IRNodeComparator : public IRVisitor {
std::unordered_map<int, int> id_map;

bool recursively_check_;

// Compare if two IRNodes definitely have the same value instead.
// When this is true, it's weaker in the sense that we don't require the
// activate field in the GlobalPtrStmt to be the same, but stronger in the
// sense that we require the value to be the same (especially stronger in
// GlobalLoadStmt, RandStmt, etc.).
bool check_same_value_;
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved

std::unordered_set<AsyncState> possibly_modified_states_;
bool all_states_can_be_modified_;
IRBank *ir_bank_;

public:
bool same;

explicit IRNodeComparator(IRNode *other_node,
std::optional<std::unordered_map<int, int>> id_map,
bool check_same_value)
explicit IRNodeComparator(
IRNode *other_node,
const std::optional<std::unordered_map<int, int>> &id_map,
bool check_same_value,
const std::optional<std::unordered_set<AsyncState>>
&possibly_modified_states,
IRBank *ir_bank)
: other_node(other_node) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
same = true;
if (id_map.has_value()) {
recursively_check_ = true;
this->id_map = std::move(id_map.value());
this->id_map = id_map.value();
} else {
recursively_check_ = false;
}
if (possibly_modified_states.has_value()) {
TI_ASSERT_INFO(check_same_value,
"The parameter possibly_modified_states "
"is only supported when check_same_value is true");
TI_ASSERT_INFO(ir_bank,
"The parameter possibly_modified_states "
"requires ir_bank")
all_states_can_be_modified_ = false;
this->possibly_modified_states_ = possibly_modified_states.value();
} else {
all_states_can_be_modified_ = true;
}
check_same_value_ = check_same_value;
ir_bank_ = ir_bank;
}

void map_id(int this_id, int other_id) {
Expand Down Expand Up @@ -97,22 +125,68 @@ class IRNodeComparator : public IRVisitor {
same = false;
return;
}

// If two identical statements can have different values, return false.
if (check_same_value_ && !stmt->is_container_statement() &&
!stmt->common_statement_eliminable()) {
same = false;
auto other = other_node->as<Stmt>();
if (stmt == other) {
return;
}

// If two identical statements can have different values, return false.
// TODO: actually the condition should be "can stmt be an operand of
// another statement?"
const bool stmt_has_value = !stmt->is_container_statement();
// TODO: We want to know if two identical statements of the type same as
// stmt can have different values. In most cases, this property is the
// same as Stmt::common_statement_eliminable(). However, two identical
// GlobalPtrStmts cannot have different values, although
// GlobalPtrStmt::common_statement_eliminable() is false.
const bool identical_stmts_can_have_different_value =
stmt_has_value && !stmt->common_statement_eliminable() &&
!stmt->is<GlobalPtrStmt>();
// Note that we do not need to test !stmt2->common_statement_eliminable()
// because if this condition does not hold,
// same_statements(stmt1, stmt2) returns false anyway.
// same_value(stmt1, stmt2) returns false anyway.
if (check_same_value_ && identical_stmts_can_have_different_value) {
if (all_states_can_be_modified_) {
same = false;
return;
} else {
bool same_value = false;
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto global_ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
TI_ASSERT(global_ptr->width() == 1);
if (possibly_modified_states_.count(ir_bank_->get_async_state(
global_ptr->snodes[0], AsyncState::Type::value)) == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, previous we were too conservative, and returned "not the same" even for snodes that are only read in the tasks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.

same_value = true;
}
}
// TODO: other cases?
}
if (!same_value) {
same = false;
return;
}
}
}

// field check
auto other = other_node->as<Stmt>();
if (!stmt->field_manager.equal(other->field_manager)) {
same = false;
return;
if (check_same_value_ && stmt->is<GlobalPtrStmt>()) {
// Special case: we do not care the "activate" field when checking
// whether two global pointers share the same value.
// And we cannot use irpass::analysis::definitely_same_address()
// directly because that function does not support id_map.

// TODO: Update this part if GlobalPtrStmt comes to have more fields
TI_ASSERT(stmt->width() == 1);
if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
other->as<GlobalPtrStmt>()->snodes[0]->id) {
same = false;
return;
}
} else {
// field check
if (!stmt->field_manager.equal(other->field_manager)) {
same = false;
return;
}
}

// operand check
Expand Down Expand Up @@ -219,8 +293,15 @@ class IRNodeComparator : public IRVisitor {
static bool run(IRNode *root1,
IRNode *root2,
const std::optional<std::unordered_map<int, int>> &id_map,
bool check_same_value) {
IRNodeComparator comparator(root2, id_map, check_same_value);
bool check_same_value,
const std::optional<std::unordered_set<AsyncState>>
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
&possibly_modified_states,
IRBank *ir_bank) {
// We need to distinguish the case of an empty
// std::unordered_set<AsyncState> (assuming every SNodes are unchanged)
// and empty (assuming nothing), so we use std::optional<> here.
IRNodeComparator comparator(root2, id_map, check_same_value,
possibly_modified_states, ir_bank);
root1->accept(&comparator);
return comparator.same;
}
Expand Down Expand Up @@ -270,17 +351,36 @@ bool same_statements(
if (!root1 || !root2)
return false;
return IRNodeComparator::run(root1, root2, id_map,
/*check_same_value=*/false);
/*check_same_value=*/false, std::nullopt,
/*ir_bank=*/nullptr);
}
bool same_value(Stmt *stmt1,
Stmt *stmt2,
const AsyncStateSet &possibly_modified_states,
IRBank *ir_bank,
const std::optional<std::unordered_map<int, int>> &id_map) {
// Test if two statements definitely have the same value.
if (stmt1 == stmt2)
return true;
if (!stmt1 || !stmt2)
return false;
return IRNodeComparator::run(
stmt1, stmt2, id_map, /*check_same_value=*/true,
std::make_optional<std::unordered_set<AsyncState>>(
possibly_modified_states.s),
ir_bank);
}
bool same_value(Stmt *stmt1,
Stmt *stmt2,
const std::optional<std::unordered_map<int, int>> &id_map) {
// Test if two statements must have the same value.
// Test if two statements definitely have the same value.
if (stmt1 == stmt2)
return true;
if (!stmt1 || !stmt2)
return false;
return IRNodeComparator::run(stmt1, stmt2, id_map, /*check_same_value=*/true);
return IRNodeComparator::run(stmt1, stmt2, id_map,
/*check_same_value=*/true, std::nullopt,
/*ir_bank=*/nullptr);
}
} // namespace irpass::analysis

Expand Down
41 changes: 33 additions & 8 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ControlFlowGraph;

struct TaskMeta;
class IRBank;
class AsyncStateSet;

// IR Analysis
namespace irpass::analysis {
Expand Down Expand Up @@ -82,20 +83,44 @@ std::vector<Stmt *> get_store_destination(Stmt *store_stmt);
bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
std::pair<bool, Stmt *> last_store_or_atomic(IRNode *root, Stmt *var);
bool maybe_same_address(Stmt *var1, Stmt *var2);
/** Test if root1 and root2 are the same, i.e., have the same type,
* the same operands, the same fields, and the same containing statements.
/**
* Test if root1 and root2 are the same, i.e., have the same type,
* the same operands, the same fields, and the same containing statements.
*
* @param id_map
* If id_map is std::nullopt by default, two operands are considered
* the same if they have the same id and do not belong to either root,
* or they belong to root1 and root2 at the same position in the roots.
* Otherwise, this function also recursively check the operands until
* ids in the id_map are reached.
* @param id_map
* If id_map is std::nullopt by default, two operands are considered
* the same if they have the same id and do not belong to either root,
* or they belong to root1 and root2 at the same position in the roots.
* Otherwise, this function also recursively check the operands until
* ids in the id_map are reached.
*/
bool same_statements(
IRNode *root1,
IRNode *root2,
const std::optional<std::unordered_map<int, int>> &id_map = std::nullopt);
/**
* Test if stmt1 and stmt2 definitely have the same value.
*
* @param possibly_modified_states
* Assumes that only states in possibly_modified_states can be modified
* between stmt1 and stmt2.
*
* @param id_map
* Same as in same_statements(root1, root2, id_map).
*/
bool same_value(
Stmt *stmt1,
Stmt *stmt2,
const AsyncStateSet &possibly_modified_states,
IRBank *ir_bank,
const std::optional<std::unordered_map<int, int>> &id_map = std::nullopt);
/**
* Test if stmt1 and stmt2 definitely have the same value.
* Any global fields can be modified between stmt1 and stmt2.
*
* @param id_map
* Same as in same_statements(root1, root2, id_map).
*/
bool same_value(
Stmt *stmt1,
Stmt *stmt2,
Expand Down
6 changes: 6 additions & 0 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ struct TaskMeta {
void print() const;
};

// A wrapper class for the parameter in bool same_value() in analysis.h.
class AsyncStateSet {
public:
std::unordered_set<AsyncState> s;
};

class IRBank;

TaskMeta *get_task_meta(IRBank *bank, const TaskLaunchRecord &t);
Expand Down
10 changes: 9 additions & 1 deletion taichi/program/ir_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,15 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
return result;
}

TI_TRACE("Begin uncached fusion");
TI_TRACE("Begin uncached fusion: [{}(size={})] <- [{}(size={})]",
handle_a.ir()->get_kernel()->name,
(handle_a.ir()->as<OffloadedStmt>()->has_body()
? handle_a.ir()->as<OffloadedStmt>()->body->size()
: -1),
handle_b.ir()->get_kernel()->name,
(handle_b.ir()->as<OffloadedStmt>()->has_body()
? handle_a.ir()->as<OffloadedStmt>()->body->size()
: -1));
// We are about to change both |task_a| and |task_b|. Clone them first.
auto cloned_task_a = handle_a.clone();
auto cloned_task_b = handle_b.clone();
Expand Down
29 changes: 20 additions & 9 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,12 +654,22 @@ std::unordered_set<int> StateFlowGraph::fuse_range(int begin, int end) {

auto edge_fusible = [&](int a, int b) {
TI_PROFILER("edge_fusible");
TI_ASSERT(a >= 0 && a < nodes.size() && b >= 0 && b < nodes.size());
// Check if a and b are fusible if there is an edge (a, b).
if (fused[a] || fused[b] || !fusion_meta[a].fusible ||
fusion_meta[a] != fusion_meta[b]) {
return false;
}
if (nodes[a]->meta->type != OffloadedTaskType::serial) {
std::unordered_map<int, int> offload_map;
offload_map[0] = 0;
std::unordered_set<AsyncState> modified_states =
get_task_meta(ir_bank_, nodes[a]->rec)->output_states;
std::unordered_set<AsyncState> modified_states_b =
get_task_meta(ir_bank_, nodes[a]->rec)->output_states;
modified_states.insert(modified_states_b.begin(),
modified_states_b.end());
AsyncStateSet modified_states_set{modified_states};
for (auto state_iter = nodes[a]->output_edges.get_state_iterator();
!state_iter.done(); ++state_iter) {
auto state = state_iter.get_state();
Expand Down Expand Up @@ -689,11 +699,10 @@ std::unordered_set<int> StateFlowGraph::fuse_range(int begin, int end) {
TI_ASSERT(nodes[b]->rec.stmt()->id == 0);
// Only map the OffloadedStmt to see if both SNodes are loop-unique
// on the same statement.
std::unordered_map<int, int> offload_map;
offload_map[0] = 0;
for (int i = 0; i < (int)ptr1->indices.size(); i++) {
if (!irpass::analysis::same_value(
ptr1->indices[i], ptr2->indices[i],
ptr1->indices[i], ptr2->indices[i], modified_states_set,
ir_bank_,
std::make_optional<std::unordered_map<int, int>>(
offload_map))) {
return false;
Expand Down Expand Up @@ -781,12 +790,14 @@ std::unordered_set<int> StateFlowGraph::fuse_range(int begin, int end) {
// Fuse no more than one task into task i
bool i_updated = false;
for (auto &edge : nodes[i]->output_edges.get_all_edges()) {
const int j = edge.second->pending_node_id - begin;
if (j != -1 && edge_fusible(i, j)) {
do_fuse(i, j);
// Iterators of nodes[i]->output_edges may be invalidated
i_updated = true;
break;
if (edge.second->pending()) {
const int j = edge.second->pending_node_id - begin;
if (j >= 0 && j < nodes.size() && edge_fusible(i, j)) {
do_fuse(i, j);
// Iterators of nodes[i]->output_edges may be invalidated
i_updated = true;
break;
}
}

if (i_updated) {
Expand Down
7 changes: 2 additions & 5 deletions tests/python/test_bit_array_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,21 @@ def test_vectorized_struct_for():
@ti.kernel
def init():
for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
(boundary_offset, N - boundary_offset)):
(boundary_offset, N - boundary_offset)):
x[i, j] = ti.random(dtype=ti.i32) % 2


@ti.kernel
def assign_vectorized():
ti.bit_vectorize(32)
for i, j in x:
y[i, j] = x[i, j]


@ti.kernel
def verify():
for i, j in ti.ndrange((boundary_offset, N - boundary_offset),
(boundary_offset, N - boundary_offset)):
(boundary_offset, N - boundary_offset)):
assert y[i, j] == x[i, j]


init()
assign_vectorized()
verify()