Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Move AttrStmt to HalideIR #21

Merged
merged 1 commit into from
Jan 20, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from af2a2f to b6637f
29 changes: 1 addition & 28 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,34 +49,6 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};

/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This is used to insert hint(shape, storage, split) about certain scopes.
*/
struct AttrStmt : public StmtNode<AttrStmt> {
/*! \brief this is attribute about certain node */
NodeRef node;
/*! \brief the type key of the attribute */
std::string type_key;
/*! \brief The attribute value, value is well defined at current scope. */
Expr value;
/*! \brief The body statement to be executed */
Stmt body;

/*! \brief construct expr from name and rdom */
static Stmt make(NodeRef node, std::string type_key, Expr value, Stmt body);

void VisitAttrs(AttrVisitor* v) final {
v->Visit("node", &node);
v->Visit("type_key", &type_key);
v->Visit("value", &value);
v->Visit("body", &body);
}

static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "AttrStmt";
};

// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
Expand Down Expand Up @@ -106,6 +78,7 @@ using Halide::Internal::Broadcast;
using Halide::Internal::Call;
using Halide::Internal::Let;
using Halide::Internal::LetStmt;
using Halide::Internal::AttrStmt;
using Halide::Internal::AssertStmt;
using Halide::Internal::ProducerConsumer;
using Halide::Internal::For;
Expand Down
2 changes: 0 additions & 2 deletions src/c_api/c_api_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ using RetValue = APIVariantValue;

TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode;
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
os << args.at(0).operator NodeRef();
Expand Down
1 change: 0 additions & 1 deletion src/c_api/c_api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ using RetValue = APIVariantValue;

TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
Expand Down
23 changes: 0 additions & 23 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
}

template<>
void StmtNode<AttrStmt>::accept(IRVisitor *v, const Stmt&) const {
LOG(FATAL) << "AttrStmt do not work with old Visitor, use IRFunctor style visitor";
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce("
Expand All @@ -34,15 +29,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ", rdom=" << op->rdom << ")";
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
});

} // namespace Internal
} // namespace Halide

Expand All @@ -62,15 +48,6 @@ Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
return Expr(n);
}

Stmt AttrStmt::make(NodeRef node, std::string type_key, Expr value, Stmt body) {
auto n = std::make_shared<AttrStmt>();
n->node = node;
n->type_key = type_key;
n->value = value;
n->body = body;
return Stmt(n);
}

TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);

Expand Down
9 changes: 1 addition & 8 deletions src/pass/ir_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst;
}

// namespace to register the functors.
namespace {

using namespace Halide::Internal;

// const expr
inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
return e;
Expand Down Expand Up @@ -290,7 +285,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return s;
})
.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) {
Region new_bounds;
Halide::Internal::Region new_bounds;
bool bounds_changed = false;

// Mutate the bounds
Expand Down Expand Up @@ -350,7 +345,5 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return Evaluate::make(v);
}
});

} // namespace
} // namespace ir
} // namespace tvm