Add gelu fuse ops (apache#18082)
* Add LeakyReLU:Gelu (fwd and bwd) to fused ops

* Add test LeakyReLU:gelu

* cpplint

* fix lint

* fix bug SQRT_2 using constant memory

* add comments
MoisesHer authored and AntiZpvoh committed Jul 6, 2020
1 parent d786f62 commit 67c7b55
Showing 4 changed files with 86 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/executor/pointwise_fusion_pass.cc
@@ -71,6 +71,20 @@ namespace {
op_name) !=
variable_io_ops.end())
return true;
if (op_name == "LeakyReLU") {
std::string act_type = n->attrs.dict.at("act_type");
if (LeakyReLU_ops.count(act_type))
return true;
else
return false;
}
if (op_name == "_backward_LeakyReLU") {
std::string act_type = n->attrs.dict.at("act_type");
if (LeakyReLU_bwd_ops.count(act_type))
return true;
else
return false;
}
return false;
}

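Both new branches follow the same pattern: read the node's act_type attribute and accept the node for fusion only if a fused implementation of that activation is registered. A more compact, behaviour-preserving form would also be possible (a sketch only, not part of this commit):

if (op_name == "LeakyReLU" || op_name == "_backward_LeakyReLU") {
  // Forward and backward descriptor tables are defined in fused_op-inl.h
  const auto& table = (op_name == "LeakyReLU") ? LeakyReLU_ops : LeakyReLU_bwd_ops;
  return table.count(n->attrs.dict.at("act_type")) > 0;
}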
23 changes: 23 additions & 0 deletions src/operator/fusion/fused_op-inl.h
@@ -224,6 +224,14 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}}
};

// LeakyReLU ops: based on "act_type" attribute
const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_ops = {
{"gelu" , {{"op::gelu(%)", "_0"}}},
};
const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_bwd_ops = {
{"gelu" , {{"op::backward_gelu(%, %)", "_1", "_0"}}},
};
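// Note (editorial): each entry follows the same descriptor convention as the
// ops_desc table above: the first string is a pattern whose "%" placeholders
// are argument slots, and the remaining "_N" strings pick which node input
// fills each slot, in order.
//   {"op::gelu(%)", "_0"}                   -> input 0 fills the single slot
//   {"op::backward_gelu(%, %)", "_1", "_0"} -> input 1 fills the first slot
//       (the forward input, per backward_gelu(val, grad) below) and input 0
//       fills the second slot (the incoming gradient).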

const std::map<std::string, std::string> slice_ops = {
{"slice_axis" , ""},
{"slice" , ""},
@@ -543,6 +551,14 @@ __device__ inline DType relu(const DType val) {
return val > 0 ? val : 0;
}
const float SQRT_2 = 1.4142135623730950488016887242096;
// compatible with mshadow_op.h version
template <typename DType>
__device__ inline DType gelu(const DType val) {
return DType(0.5f * static_cast<float>(val) *
(1.0f + erf(static_cast<float>(val) / SQRT_2)));
}
template <typename DType>
__device__ inline DType sigmoid(const DType val) {
return 1.f/(1 + expf(-val));
@@ -987,6 +1003,13 @@ __device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 sca
}
}
// compatible with mshadow_op.h version
template <typename DType, typename DTypeGrad>
__device__ inline DTypeGrad backward_gelu(const DType val, const DTypeGrad grad) {
return grad * DType(0.5f * (1.0f + erf(static_cast<float>(val) / SQRT_2) +
static_cast<float>(val) * backward_erf(static_cast<float>(val) / SQRT_2, 1.0f) / SQRT_2));
}
} // namespace op
)code";
36 changes: 36 additions & 0 deletions src/operator/fusion/fused_op.cu
@@ -460,6 +460,42 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
continue;
}

// LeakyReLU, look for act_type
if (op_name == "LeakyReLU") {
std::string act_type = node.source->attrs.dict.at("act_type");
const std::vector<std::vector<std::string>>& op_descs =
fusion::LeakyReLU_ops.at(act_type);
if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_ops.end()) {
CHECK_EQ(outputs[i], op_descs.size());
size_t count = 0;
for (const auto& op_desc : op_descs) {
var_name = "temp" + std::to_string(temp_name_counter++);
const std::string& fmt = ParseOpDescription(op_desc, variables, node);
code += "const auto " + var_name + " = " + fmt + ";\n";
variables[{i, count}] = var_name;
++count;
}
continue;
}
}
if (op_name == "_backward_LeakyReLU") {
std::string act_type = node.source->attrs.dict.at("act_type");
const std::vector<std::vector<std::string>>& op_descs =
fusion::LeakyReLU_bwd_ops.at(act_type);
if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_bwd_ops.end()) {
CHECK_EQ(outputs[i], op_descs.size());
size_t count = 0;
for (const auto& op_desc : op_descs) {
var_name = "temp" + std::to_string(temp_name_counter++);
const std::string& fmt = ParseOpDescription(op_desc, variables, node);
code += "const auto " + var_name + " = " + fmt + ";\n";
variables[{i, count}] = var_name;
++count;
}
continue;
}
}

LOG(FATAL) << "Unrecognized op " + op_name;
}
} else {
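To make these two branches concrete: for the graph exercised by the new test below, mx.sym.LeakyReLU(a + b, act_type='gelu'), the forward pass would contribute lines of roughly this shape to the fused kernel body (a sketch only: variable names are illustrative, the real ones come from temp_name_counter and the input-loading code, and the addition is assumed to go through the existing elemwise_add descriptor):

// Hypothetical excerpt of the generated fused kernel body
const auto temp0 = op::add(input_a, input_b);  // a + b
const auto temp1 = op::gelu(temp0);            // LeakyReLU(act_type='gelu')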
13 changes: 13 additions & 0 deletions tests/python/gpu/test_fusion.py
@@ -230,11 +230,24 @@ def check_other_ops():
arr2 = mx.random.uniform(shape=(2,2,2,3))
check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)

def check_leakyrelu_ops():
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
shape = rand_shape_2d()
arr1 = mx.random.uniform(shape=shape)
arr2 = mx.random.uniform(shape=shape)

# Testing gelu
print("Checking fusion of LeakyReLU:gelu")
check_fused_symbol(mx.sym.LeakyReLU(a+b, act_type='gelu'), a=arr1, b=arr2)


@with_seed()
def test_fusion():
check_unary_ops()
check_binary_ops()
check_other_ops()
check_leakyrelu_ops()

@with_seed()
def test_fusion_compiler_cache():
