Skip to content

Commit

Permalink
[CODEGEN] Add LoweredFunc, MakeAPI and SplitHostDevice
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 22, 2017
1 parent 3c1020d commit 132ec47
Show file tree
Hide file tree
Showing 25 changed files with 1,345 additions and 348 deletions.
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from adfa66 to d82003
3 changes: 3 additions & 0 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Buffer : public NodeRef {
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;

/*! \brief specify container node */
using ContainerType = BufferNode;
};

/*! \brief Node to represent a buffer */
Expand Down
40 changes: 34 additions & 6 deletions include/tvm/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif

#include <stdint.h>
#include <stddef.h>


TVM_EXTERN_C {
Expand Down Expand Up @@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);

/*!
* \brief Launch a generated TVM function
* \brief TVM Function API: Get resource requirement
*
* By default TVM function try not to do internal allocations.
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
*/
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);

/*!
* \brief TVM Function API: Launch generated function.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
*
* \sa TVMFuncRequirement
*/
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream);
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream,
TVMArrayHandle workspace);
} // TVM_EXTERN_C

#endif // TVM_C_RUNTIME_API_H_
68 changes: 68 additions & 0 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*!
* Copyright (c) 2016 by Contributors
* \file codegen.h
* \brief Collection of Lowlevel IR pass to codegen.
*/
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_

#include <string>
#include "./base.h"
#include "./expr.h"
#include "./module.h"

namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);

/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

} // namespace codegen
} // namespace tvm

#endif // TVM_CODEGEN_H_
42 changes: 42 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};

/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
/*!
* \brief See pesudo code
*
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
* assert(arg_type_id[i] == typeid(Type));
* return args[i];
* }
*/
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
/*!
* \brief See pesudo code
*
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
* return arr->field;
* }
* \sa TVMArrayFieldKind
*/
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
/*!
* \brief See pesudo code
*
* bool tvm_handle_is_null(void* handle) {
* return handle == nullptr
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
kData = 0,
kNDim = 1,
kShape = 2,
kStrides = 3,
kTypeCode = 4,
kTypeBits = 5,
kTypeLanes = 6
};
} // namespace intrinsic

// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
#include "./ir.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -51,6 +52,20 @@ class IRMutator {
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
};

/*!
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
*/
bool VerifySSA(const Stmt& ir);

/*!
* \brief Whether the expression have side effect.
* \return whether expression have side effect
*/
bool HasSideEffect(const Expr& e);

/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
Expand All @@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
Expr body);


/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ class IRVisitor {
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
};

/*!
Expand Down
108 changes: 108 additions & 0 deletions include/tvm/module.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*!
* Copyright (c) 2016 by Contributors
* \file module.h
* \brief Low level IR module,
* Contains lowered function information.
*/
#ifndef TVM_MODULE_H_
#define TVM_MODULE_H_

#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>

#include "./base.h"
#include "./expr.h"
#include "./tensor.h"

namespace tvm {

// Internal node container of lowered function.
class LoweredFuncNode;

// Internal node container of module.
class ModuleNode;

/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const LoweredFuncNode* operator->() const;
/*! \brief specify container node */
using ContainerType = LoweredFuncNode;
};

/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
* This function can only take pod type(int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
* Each axis need host function to specify a size.
* \note Calling convention into LoweredFunc
*
* Assume we have a LoweredFunc f, a call into f
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
* The CodeGen should take this and translate this call
* to corresponding API specific kernel launchs or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
* Can be used as hint when generating type signiture.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
* \note Expr is used instead Type, because Type cannot be hold by Map.
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
// there is no return value, but return 1
// to enable Call into this function.
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
};

// Implementations of inline functions
inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(node_.get());
}

} // namespace tvm

#endif // TVM_MODULE_H_
6 changes: 6 additions & 0 deletions python/tvm/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass


@register_node
class LoweredFunc(NodeBase):
"""Represent a LoweredFunc in TVM."""
pass
3 changes: 2 additions & 1 deletion src/base/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define TVM_BASE_COMMON_H_

#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>

namespace tvm {
Expand All @@ -30,7 +31,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 32, 1);
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
Expand Down
Loading

0 comments on commit 132ec47

Please sign in to comment.