Add PrioritizeLoad to CudaCodeGen. (pytorch#179)
zheng-xq authored Feb 20, 2020
1 parent ea1e2ad commit 69fc6ac
Showing 1 changed file with 117 additions and 1 deletion.
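
In short, the new PrioritizeLoad pass walks the kernel statement and hoists every memory Load into a fresh scalar bound by a LetStmt at the top of the enclosing statement (loop body, let body, branch, or the whole kernel), so the consuming expression reads the scalar instead of the buffer directly. The intended effect on the emitted CUDA source is roughly the following; this is an illustrative sketch with made-up buffer and variable names, not output produced by this commit:

    // Before: loads are issued inline, inside the expression that uses them.
    c[i] = a[i] + b[i];

    // After: each load is bound to a local scalar first, so the loads can be
    // issued (and their latency overlapped) before the arithmetic that uses them.
    float v = a[i];
    float v_1 = b[i];
    c[i] = v + v_1;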
torch/csrc/jit/tensorexpr/cuda_codegen.cpp (117 additions, 1 deletion)
@@ -189,6 +189,119 @@ void CudaPrinter::visit(const IfThenElse* v) {
   v->false_value().accept(this);
 }
 
+class PrioritizeLoad : public IRMutator {
+ public:
+  virtual Expr mutate(const Load* v) {
+    MemLoadList& load_list = load_stack_.back();
+    Var load_new_var{"v", v->dtype()};
+    Expr new_value = IRMutator::mutate(v);
+    load_list.push_back(std::make_pair(load_new_var.node(), new_value));
+    return load_new_var;
+  }
+
+  // TODO: merge this with the IRMutator::mutate version.
+  virtual Stmt mutate(const For* v) {
+    Var var = v->var();
+    Expr start = v->start();
+    Expr stop = v->stop();
+    Stmt body = v->body();
+    LoopOptions loop_options = v->loop_options();
+    Expr var_new_expr = var.accept_mutator(this);
+    Var var_new = Var(var_new_expr.AsNode<Variable>());
+    Expr start_new = start.accept_mutator(this);
+    Expr stop_new = stop.accept_mutator(this);
+    PushList();
+    Stmt body_new = body.accept_mutator(this);
+    Stmt body_with_loads = AddMemLoadsFromList(body_new);
+    PopList();
+    if (same_node(var, var_new) && same_node(start, start_new) &&
+        same_node(stop, stop_new) && same_node(body, body_with_loads)) {
+      return Stmt(v);
+    }
+    return For::make(
+        var_new, start_new, stop_new, body_with_loads, loop_options);
+  }
+
+  virtual Stmt mutate(const LetStmt* v) {
+    Var var = v->var();
+    Expr value = v->value();
+    Stmt body = v->body();
+    Expr var_new_expr = var.accept_mutator(this);
+    Variable* var_new_ptr = var_new_expr.AsNode<Variable>();
+    if (var_new_ptr == nullptr) {
+      throw std::runtime_error("LetStmt var must be variable");
+    }
+    Var var_new{var_new_ptr};
+    Expr value_new = value.accept_mutator(this);
+    PushList();
+    Stmt body_new = body.accept_mutator(this);
+    Stmt body_with_loads = AddMemLoadsFromList(body_new);
+    PopList();
+    if (same_node(var, var_new) && same_node(value, value_new) &&
+        same_node(body, body_with_loads)) {
+      return Stmt(v);
+    }
+    return LetStmt::make(var_new, value_new, body_with_loads);
+  }
+
+  virtual Stmt mutate(const Cond* v) {
+    Expr cond_old = v->condition();
+    Stmt true_old = v->true_stmt();
+    Stmt false_old = v->false_stmt();
+
+    Expr cond_new = cond_old.accept_mutator(this);
+    PushList();
+    Stmt true_new = true_old.accept_mutator(this);
+    Stmt true_with_loads = AddMemLoadsFromList(true_new);
+    PopList();
+    PushList();
+    Stmt false_new = false_old.accept_mutator(this);
+    Stmt false_with_loads = AddMemLoadsFromList(false_new);
+    PopList();
+
+    if (same_node(cond_old, cond_new) && same_node(true_old, true_with_loads) &&
+        same_node(false_old, false_with_loads)) {
+      return Stmt(v);
+    }
+    return Cond::make(cond_new, true_with_loads, false_with_loads);
+  }
+
+  Stmt Process(const Stmt& stmt) {
+    this->PushList();
+    Stmt stmt_v = stmt;
+    Stmt stmt_new = stmt_v.accept_mutator(this);
+    Stmt stmt_with_loads = AddMemLoadsFromList(stmt_new);
+    this->PopList();
+    return stmt_with_loads;
+  }
+
+ private:
+  using MemLoadEntry = std::pair<const Variable*, Expr>;
+  using MemLoadList = std::vector<MemLoadEntry>;
+  using MemoryLoadStack = std::vector<MemLoadList>;
+
+  void PushList() {
+    load_stack_.push_back(MemLoadList());
+  }
+
+  void PopList() {
+    load_stack_.pop_back();
+  }
+
+  Stmt AddMemLoadsFromList(const Stmt& stmt) {
+    MemLoadList& load_list = load_stack_.back();
+    Stmt stmt_v = stmt;
+    for (int i = load_list.size() - 1; i >= 0; i--) {
+      const MemLoadEntry& entry = load_list[i];
+      Variable* var_ptr = const_cast<Variable*>(entry.first);
+      stmt_v = LetStmt::make(Var(var_ptr), entry.second, stmt_v);
+    }
+    return stmt_v;
+  }
+
+  MemoryLoadStack load_stack_;
+};
+
 void CudaCodeGen::Initialize() {
   printer_.reset(new CudaPrinter(&oss_));
   // TODO: handle multiple kernels.
@@ -209,7 +322,10 @@ void CudaCodeGen::Initialize() {
   os() << ") {";
 
   os() << std::endl;
-  stmt().accept(printer_.get());
+  Stmt stmt_v = stmt();
+  PrioritizeLoad prioritize_load;
+  stmt_v = prioritize_load.Process(stmt_v);
+  stmt_v.accept(printer_.get());
   os() << std::endl;
   os() << "}";

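One detail worth calling out: mutate(const For*) pushes a fresh load list around the loop body, so loads that occur inside a loop are bound at the top of that loop's body (once per iteration) rather than being hoisted above the loop, and the same scoping applies to LetStmt bodies and to each branch of a Cond. A sketch of the intended effect on emitted source for a loop, again with illustrative names and syntax rather than actual output of this commit:

    for (int i = 0; i < N; i++) {
      float v = a[i];     // loads bound at the top of the body, per iteration
      float v_1 = b[i];
      c[i] = v + v_1;
    }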