Skip to content

Commit

Permalink
Remove activate node (PaddlePaddle#186)
Browse files Browse the repository at this point in the history
* remove activate op

* rename optim map_extern_call

* code clean
  • Loading branch information
Superjomn authored Aug 26, 2020
1 parent 7dcc4fe commit 2c48601
Show file tree
Hide file tree
Showing 14 changed files with 7 additions and 139 deletions.
4 changes: 0 additions & 4 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,6 @@ void CodeGenC::Visit(const ir::Not *op) {
IrPrinter::Print(op->v());
os() << ")";
}
void CodeGenC::Visit(const ir::Activate *op) {
// Should be replaced by a tanh function call.
CINN_NOT_IMPLEMENTED
}
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); }
void CodeGenC::Visit(const ir::For *op) {
os() << "for (";
Expand Down
4 changes: 0 additions & 4 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,6 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Add *op) {
return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '+', is_integral_type(op->type()));
}

llvm::Value *CodeGenLLVM::Visit(const ir::Activate *) { // Should be replaced by a extern call.
CINN_NOT_IMPLEMENTED;
}

llvm::Value *CodeGenLLVM::Visit(const ir::Sub *op) {
return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '-', is_integral_type(op->type()));
}
Expand Down
20 changes: 0 additions & 20 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,26 +368,6 @@ struct Call : public ExprNode<Call> {
static const IrNodeTy _node_type_ = IrNodeTy::Call;
};

struct Activate : public UnaryOpNode<Activate> {
explicit Activate(Expr x) : UnaryOpNode<Activate>(x.type(), x) {}
enum class Kind {
kTanh,
kSigmoid,
kExp,
kCeil,
kFloor,
};
Kind kind;

static Expr Make(Kind kind, Expr arg) {
auto n = make_shared<Activate>(arg);
n->kind = kind;
return Expr(n);
}

static const IrNodeTy _node_type_ = IrNodeTy::Activate;
};

/**
* Variable used as iterator value or bound definition.
*/
Expand Down
5 changes: 0 additions & 5 deletions cinn/ir/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,6 @@ void IRMutator<T>::Visit(const Free *expr, T op) {
IRVisitorBase<void, T>::Visit(&node->destination, &node->destination);
}
template <typename T>
void IRMutator<T>::Visit(const Activate *expr, T op) {
auto *node = op->template As<Activate>();
IRVisitorBase<void, T>::Visit(&node->operand(0), &node->operand(0));
}
template <typename T>
void IRMutator<T>::Visit(const _Range_ *expr, T op) {}
template <typename T>
void IRMutator<T>::Visit(const _Buffer_ *expr, T op) {
Expand Down
25 changes: 0 additions & 25 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,6 @@ void IrPrinter::Visit(const GT *x) { PrintBinaryOp(">", x); }
void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); }
void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); }
void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); }
void IrPrinter::Visit(const Activate *x) {
switch (x->kind) {
case Activate::Kind::kTanh:
os() << "tanh";
break;
case Activate::Kind::kSigmoid:
os() << "sigmoid";
break;
case Activate::Kind::kCeil:
os() << "ceil";
break;
case Activate::Kind::kFloor:
os() << "floor";
break;
case Activate::Kind::kExp:
os() << "exp";
break;
default:
CINN_NOT_IMPLEMENTED
}

os() << "(";
Print(x->operand(0));
os() << ")";
}
void IrPrinter::Visit(const Not *x) {
os_ << "!";
Print(x->v());
Expand Down
56 changes: 0 additions & 56 deletions cinn/ir/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,6 @@
namespace cinn {
namespace ir {

/*
namespace {
struct IrNodesCollector : public IRVisitor {
using teller_t = std::function<bool(const Expr*)>;
using handler_t = std::function<void(const Expr*)>;
teller_t teller;
handler_t handler;
IrNodesCollector(teller_t&& teller, handler_t&& handler) : teller(teller), handler(handler) {}
void Visit(const Expr* expr) override {
if (!expr->defined()) return;
if (visited_.count(expr->get())) return;
if (teller(expr)) {
handler(expr);
}
visited_.insert(expr->get());
switch (expr->node_type()) {
#define __(op__) \
case ir::IrNodeTy::op__: \
return Visit(expr->As<ir::op__>());
NODETY_FORALL(__)
default:
LOG(FATAL) << "not supported NodeTy";
#undef __
}
}
#define __m(t__) \
void Visit(const t__* x) override { \
for (auto* n : x->expr_fields()) { \
if (n->defined()) Visit(n); \
} \
}
NODETY_FORALL(__m)
#undef __m
std::unordered_set<void*> visited_;
};
} // namespace
std::set<Expr> CollectIRNodes(Expr expr, std::function<bool(const Expr*)> teller) {
std::set<Expr> exprs;
IrNodesCollector::handler_t handler = [&](const Expr* x) { exprs.insert(*x); };
IrNodesCollector collector(std::move(teller), std::move(handler));
collector.Visit(&expr);
return exprs;
}
*/

bool operator==(Expr a, Expr b) {
if (a.get() == b.get()) return true;
// TODO(Superjomn) implement with a more accurate one
Expand Down
1 change: 0 additions & 1 deletion cinn/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class Var;
macro__(Power) \
macro__(Product) \
macro__(Sum) \
macro__(Activate) \
macro__(PrimitiveNode) \

#define NODETY_FORALL(__m) \
Expand Down
2 changes: 1 addition & 1 deletion cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ set(srcs remove_nested_block.cc replace_call_with_expr.cc ir_copy.cc
insert_debug_log_callee.cc
lower_function_call_bind_vars.cc
extern_call_process.cc
activate_to_extern_call.cc
map_extern_call.cc
cache_read_write_replace.cc
compute_inline_expand.cc
buffer_assign.cc
Expand Down
1 change: 0 additions & 1 deletion cinn/optim/insert_debug_log_callee.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ struct StoreDebugInfoBuilder : public ir::IRVisitor {
void Visit(const ir::IfThenElse *x) override {}
void Visit(const ir::Block *x) override {}
void Visit(const ir::Call *x) override {}
void Visit(const ir::Activate *x) override {}
void Visit(const ir::Store *x) override {
format_ << x->tensor.as_tensor()->name << "[] = ";
Visit(&x->value);
Expand Down
8 changes: 0 additions & 8 deletions cinn/optim/ir_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,6 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
return Sum::Make(operands);
}

Expr Visit(const ir::Activate* op) override {
auto arg = Visit(&op->operand(0));

auto n = common::make_shared<ir::Activate>(arg);
n->kind = op->kind;
return n;
}

Expr Visit(const ir::PrimitiveNode* op) override {
std::vector<std::vector<Expr>> arguments;
for (auto& args : op->arguments) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "cinn/optim/activate_to_extern_call.h"
#include "cinn/optim/map_extern_call.h"

#include "cinn/cinn.h"
#include "cinn/ir/ir_mutator.h"
Expand All @@ -7,7 +7,7 @@
namespace cinn {
namespace optim {

void ActivateToExternCall(Expr *e, Target target) {
void MapExternCall(Expr *e, Target target) {
struct Mutator : ir::IRMutator<Expr *> {
Target target;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ namespace cinn {
namespace optim {

/**
* Replace Activate IR nodes with extern call if needed.
* Map the Call nodes to external function call.
* TODO(Suerjomn) consider different backends.
*/
void ActivateToExternCall(Expr *e, Target target);
void MapExternCall(Expr *e, Target target);

} // namespace optim
} // namespace cinn
4 changes: 2 additions & 2 deletions cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "cinn/optim/optimize.h"

#include "cinn/ir/ir_printer.h"
#include "cinn/optim/activate_to_extern_call.h"
#include "cinn/optim/cache_read_write_replace.h"
#include "cinn/optim/call_arg_list_to_pod_value.h"
#include "cinn/optim/eliminate_broadcast_in_forloop.h"
Expand All @@ -11,6 +10,7 @@
#include "cinn/optim/ir_copy.h"
#include "cinn/optim/ir_simplify.h"
#include "cinn/optim/lower_function_call_bind_vars.h"
#include "cinn/optim/map_extern_call.h"
#include "cinn/optim/remove_nested_block.h"
#include "cinn/optim/transform_gpu_forloop.h"
#include "cinn/optim/transform_polyfor_to_for.h"
Expand Down Expand Up @@ -38,7 +38,7 @@ Expr Optimize(Expr e, Target target, bool runtime_debug_info) {

RemoveNestedBlock(&copied);

ActivateToExternCall(&copied, target);
MapExternCall(&copied, target);
ExternCallMultiOutputShallowStore(&copied);

Simplify(&copied);
Expand Down
8 changes: 0 additions & 8 deletions cinn/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,6 @@ void BindIrIr(py::module *m) {

DEFINE_UNARY_NODE(Minus);
DEFINE_UNARY_NODE(Not);
DEFINE_UNARY_NODE(Activate).def_readwrite("kind", &ir::Activate::kind);
py::enum_<ir::Activate::Kind> kind(py_Activate, "Kind");
kind.value("kTanh", ir::Activate::Kind::kTanh)
.value("kSigmoid", ir::Activate::Kind::kSigmoid)
.value("kExp", ir::Activate::Kind::kExp)
.value("kCeil", ir::Activate::Kind::kCeil)
.value("kFloor", ir::Activate::Kind::kFloor);

#undef DEFINE_UNARY_NODE

py::class_<Var, IrNodeRef> var(*m, "Var");
Expand Down

0 comments on commit 2c48601

Please sign in to comment.