Skip to content

Commit

Permalink
[LLVM/RUNTIME] Support Parallel for on CPU (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Feb 26, 2017
1 parent 2f462cc commit f6c043e
Show file tree
Hide file tree
Showing 15 changed files with 298 additions and 86 deletions.
9 changes: 5 additions & 4 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body,
int num_unpacked_args);

/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Split the function into a host function and device functions.
Expand Down
34 changes: 27 additions & 7 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx);

/*!
* \brief Free the Module
* \param mod The module to be freed.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
*
* The all functions remains valid until TVMFuncFree is called.
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);

/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
Expand All @@ -242,17 +254,25 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);

/*!
* \brief Free the Module
* \param mod The module to be freed.
* \brief Backend function for running parallel for loop.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* The all functions remains valid until TVMFuncFree is called.
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);

/*!
* \brief Free the function when it is no longer needed.
Expand Down
9 changes: 8 additions & 1 deletion include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ enum AttachType : int {
/*! \brief IterVar type */
enum IterVarType : int {
kUnrolled = 1,
kVectorized = 2
kVectorized = 2,
kParallel = 3
};

/*! \brief Stage, contains scheduling for a stage of computation. */
Expand Down Expand Up @@ -152,6 +153,12 @@ class Stage : public NodeRef {
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,13 @@ def unroll(self, var):
The iteration to be unrolled.
"""
_api_internal._StageUnroll(self, var)

def parallel(self, var):
"""Parallelize the iteration.
Parameters
----------
var : IterVar
The iteration to be parallelized.
"""
_api_internal._StageParallel(self, var)
6 changes: 6 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize)
.vectorize(args[1]);
});

TVM_REGISTER_API(_StageParallel)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.parallel(args[1]);
});

TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
Expand Down
143 changes: 111 additions & 32 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h"
#include "../../arithmetic/compute_expr.h"

Expand All @@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
Expand All @@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_type_,
t_tvm_context_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
Expand Down Expand Up @@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());

f_tvm_parallel_for_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
this->InitTarget(target_triple);
// initialize builder
builder_.reset(new IRBuilder(*ctx));
Expand Down Expand Up @@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
}
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
builder_->CreateRet(builder_->CreateCall(f, args));
llvm::CallInst* call = builder_->CreateCall(f, args);
call->setTailCall(true);
builder_->CreateRet(call);
}

class FPassManager : public llvm::legacy::FunctionPassManager {
Expand Down Expand Up @@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return nullptr;
}

llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create(
Expand All @@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
return end_block;
}
void CodeGenLLVM::Visit_(const For* op) {
using llvm::BasicBlock;
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
CHECK(is_zero(op->min));
Type t = op->min.type();
llvm::Value* init = ConstInt32(0);
llvm::Value* extent = MakeValue(op->extent);
builder_->CreateBr(for_head);

builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(LLVMType(t), 2);
index->addIncoming(init, pre_block);
llvm::Value* cond = CreateLT(t, index, extent);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[op->loop_var.get()] = index;
this->Visit(op->body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
op->loop_var, op->body);
} else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
}

void CodeGenLLVM::Visit_(const IfThenElse* op) {
Expand Down Expand Up @@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckPackedCallSuccess(retcode);
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
Expand Down Expand Up @@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckPackedCallSuccess(
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
Expand Down Expand Up @@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
}
}

void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
// fields to be packed into closure.
Var loop_var(op->loop_var.node_);
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
CHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);

for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
f_tvm_parallel_for_,
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* begin = &(*it++);
llvm::Value* end = &(*it++);
cdata = &(*it++);
begin = CreateCast(Int(64), op->loop_var.type(), begin);
end = CreateCast(Int(64), op->loop_var.type(), end);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
}

void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->Visit(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
9 changes: 8 additions & 1 deletion src/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor {
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
Expand All @@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor {
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str);
// Create parallel for.
void CreateParallelFor(const For* op);
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Initialize target
void InitTarget(const std::string& target);
// Add a function to set global module context
Expand Down
16 changes: 16 additions & 0 deletions src/lang/lowered_func.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*!
* Copyright (c) 2017 by Contributors
* \file lowered_func.cc
*/
#include <tvm/lowered_func.h>

namespace tvm {

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) {
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});

TVM_REGISTER_NODE_TYPE(LoweredFuncNode);

} // namespace tvm
2 changes: 1 addition & 1 deletion src/pass/make_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
n->is_packed_func = num_unpacked_args == 0;
n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f);
Array<Var> undefined = UndefinedVars(f->body, f->args);
if (undefined.size() != 0) {
std::ostringstream os;
for (Var v : undefined) {
Expand Down
Loading

0 comments on commit f6c043e

Please sign in to comment.