-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CODEGEN] Add LoweredFunc, MakeAPI and SplitHostDevice
- Loading branch information
Showing
25 changed files
with
1,345 additions
and
348 deletions.
There are no files selected for viewing
Submodule HalideIR
updated
from adfa66 to d82003
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file codegen.h | ||
* \brief Collection of Lowlevel IR pass to codegen. | ||
*/ | ||
#ifndef TVM_CODEGEN_H_ | ||
#define TVM_CODEGEN_H_ | ||
|
||
#include <string> | ||
#include "./base.h" | ||
#include "./expr.h" | ||
#include "./module.h" | ||
|
||
namespace tvm { | ||
/*! \brief namespace for lowlevel IR pass and codegen */ | ||
namespace codegen { | ||
/*! | ||
* \brief Make an user callable API LoweredFunc. | ||
* | ||
* The main task of this function is to create code to : | ||
* - Map the values in the api_args to of Var that is required by body. | ||
* - Insert assertions to check type/value of the passed arguments. | ||
* | ||
* \param body The body of the function. | ||
* \param name The name of the function. | ||
* \param api_args Arguments to the function, can be either Var, or Buffer | ||
* \param num_packed_args Number of arguments that are processed in packed form. | ||
* \return a LoweredFunc with the specified signiture. | ||
* | ||
* \note | ||
* The function signiture have two cases | ||
* | ||
* if num_packed_args is zero: | ||
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) | ||
* | ||
* if num_packed_args is not zero: | ||
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, | ||
* api_arg_k, api_arg_k+1, ... api_arg_n) | ||
* | ||
* where n == len(api_args), k == num_packed_args | ||
* | ||
* There is no thread_axis in generated function. | ||
*/ | ||
LoweredFunc MakeAPI(Stmt body, | ||
std::string name, | ||
Array<NodeRef> api_args, | ||
int num_packed_args); | ||
|
||
/*! | ||
* \brief Count number of undefined vars in f. | ||
* \param f The function to be checked. | ||
* \return Number of undefined vars. | ||
*/ | ||
Array<Var> UndefinedVars(const LoweredFunc& f); | ||
|
||
/*! | ||
* \brief Split the function into a host function and device functions. | ||
* \param func The function to be splitted. | ||
* | ||
* \return Array of functions, the first one is host function, | ||
* the others are device functions. | ||
*/ | ||
Array<LoweredFunc> SplitHostDevice(LoweredFunc func); | ||
|
||
} // namespace codegen | ||
} // namespace tvm | ||
|
||
#endif // TVM_CODEGEN_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file module.h | ||
* \brief Low level IR module, | ||
* Contains lowered function information. | ||
*/ | ||
#ifndef TVM_MODULE_H_ | ||
#define TVM_MODULE_H_ | ||
|
||
#include <tvm/container.h> | ||
#include <ir/FunctionBase.h> | ||
#include <string> | ||
|
||
#include "./base.h" | ||
#include "./expr.h" | ||
#include "./tensor.h" | ||
|
||
namespace tvm { | ||
|
||
// Internal node container of lowered function. | ||
class LoweredFuncNode; | ||
|
||
// Internal node container of module. | ||
class ModuleNode; | ||
|
||
/*! | ||
* \brief LoweredFunc represents function after lowering. | ||
* This is the final IR representation before codegen. | ||
*/ | ||
class LoweredFunc : public FunctionRef { | ||
public: | ||
LoweredFunc() {} | ||
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {} | ||
/*! | ||
* \brief access the internal node container | ||
* \return the pointer to the internal node container | ||
*/ | ||
inline const LoweredFuncNode* operator->() const; | ||
/*! \brief specify container node */ | ||
using ContainerType = LoweredFuncNode; | ||
}; | ||
|
||
/*! \brief Node container of LoweredFunc */ | ||
class LoweredFuncNode : public FunctionBaseNode { | ||
public: | ||
/*! \brief The name of the function */ | ||
std::string name; | ||
/*! | ||
* \brief The arguments of the function | ||
* This function can only take pod type(int, float) and void* as arguments. | ||
*/ | ||
Array<Var> args; | ||
/*! | ||
* \brief The IterVar axis of threads | ||
* Each axis need host function to specify a size. | ||
* \note Calling convention into LoweredFunc | ||
* | ||
* Assume we have a LoweredFunc f, a call into f | ||
* Call(f, arg1, arg2, ..., arg_n, | ||
* size_axis_1, size_axis_2, ... size_axis_m) | ||
* | ||
* Here n = len(args), m = len(thread_axis) | ||
* | ||
* The CodeGen should take this and translate this call | ||
* to corresponding API specific kernel launchs or function calls. | ||
*/ | ||
Array<IterVar> thread_axis; | ||
/*! | ||
* \brief The hint data type of Var handles defined in LetStmt | ||
* Can be used as hint when generating type signiture. | ||
* The creation rule is given by | ||
* handle_data_type[var_handle] = make_const(the_type, 0); | ||
* | ||
* \note Expr is used instead Type, because Type cannot be hold by Map. | ||
* constant Expr of given type is used. | ||
*/ | ||
Map<Var, Expr> handle_data_type; | ||
/*! \brief The body statment of the function */ | ||
Stmt body; | ||
/*! \return name of the operation */ | ||
const std::string& func_name() const final { | ||
return name; | ||
} | ||
// there is no return value, but return 1 | ||
// to enable Call into this function. | ||
int num_outputs() const final { | ||
return 1; | ||
} | ||
void VisitAttrs(AttrVisitor* v) final { | ||
v->Visit("name", &name); | ||
v->Visit("args", &args); | ||
v->Visit("thread_axis", &thread_axis); | ||
v->Visit("handle_data_type", &handle_data_type); | ||
v->Visit("body", &body); | ||
} | ||
|
||
static constexpr const char* _type_key = "LoweredFunc"; | ||
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode); | ||
}; | ||
|
||
// Implementations of inline functions | ||
inline const LoweredFuncNode* LoweredFunc::operator->() const { | ||
return static_cast<const LoweredFuncNode*>(node_.get()); | ||
} | ||
|
||
} // namespace tvm | ||
|
||
#endif // TVM_MODULE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.