Merge branch 'gyshi/tf-1.15-backport-leaky' into 'tf-1.15-maint'
[TF 1.15] Backport LeakyRelu and Tanh fusion

See merge request TensorFlow/Direct-Optimization/private-tensorflow!588
Zantares committed Dec 16, 2020
2 parents b8c28c8 + d867cfc commit ab3461b
Showing 18 changed files with 488 additions and 79 deletions.
15 changes: 14 additions & 1 deletion tensorflow/core/graph/mkl_layout_pass.cc
@@ -1767,7 +1767,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
fused_ops == std::vector<string>{"LeakyRelu"} ||
fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"});
}
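As context for the new patterns above: the predicate only compares the node's fused_ops string-list attribute against the supported combinations. Below is a minimal, self-contained sketch of that check, restricted to the newly added LeakyRelu cases; the helper name is illustrative, and the real check is FusedConv2DRewrite, which also accepts the pre-existing Relu/Relu6/Elu/Add combinations.

```cpp
#include <string>
#include <vector>

#include "tensorflow/core/framework/node_def_util.h"  // GetNodeAttr
#include "tensorflow/core/graph/graph.h"               // Node

namespace tensorflow {

// Illustrative only: true when a _FusedConv2D node carries one of the newly
// supported LeakyRelu fusion patterns.
static bool HasLeakyReluFusion(const Node* n) {
  std::vector<std::string> fused_ops;
  if (!GetNodeAttr(n->def(), "fused_ops", &fused_ops).ok()) return false;
  return fused_ops == std::vector<std::string>{"LeakyRelu"} ||
         fused_ops == std::vector<std::string>{"BiasAdd", "LeakyRelu"} ||
         fused_ops == std::vector<std::string>{"BiasAdd", "Add", "LeakyRelu"};
}

}  // namespace tensorflow
```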

static bool FusedDepthwiseConv2DRewrite(const Node* n) {
@@ -2742,6 +2747,7 @@ void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
float epsilon;
std::vector<string> fused_ops;
DataType Tpaddings;
float leakyrelu_alpha;

// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
@@ -2752,6 +2758,8 @@
TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
TF_CHECK_OK(
GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));

// Add attributes to new node.
@@ -2764,6 +2772,7 @@
nb->Attr("epsilon", epsilon);
nb->Attr("Tpaddings", Tpaddings);
nb->Attr("fused_ops", fused_ops);
nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
}

void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter(
@@ -2932,6 +2941,7 @@ void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
std::vector<int32> strides;
std::vector<int32> dilations;
std::vector<string> fused_ops;
float leakyrelu_alpha;

// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
@@ -2942,6 +2952,8 @@
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
TF_CHECK_OK(
GetNodeAttr(orig_node->def(), "leakyrelu_alpha", &leakyrelu_alpha));

Node* filter_node = nullptr;
TF_CHECK_OK(orig_node->input_node(1, &filter_node));
@@ -2956,6 +2968,7 @@
nb->Attr("dilations", dilations);
nb->Attr("fused_ops", fused_ops);
nb->Attr("epsilon", epsilon);
nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
}

void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
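Both attribute-copy helpers in this file now forward leakyrelu_alpha with TF_CHECK_OK, i.e. they require the attribute to exist on the source node. Purely as an illustration — not part of this commit — a defensive variant that tolerates a missing attribute could look like the sketch below; the function name and fallback behaviour are hypothetical.

```cpp
#include "tensorflow/core/framework/node_def_util.h"  // GetNodeAttr
#include "tensorflow/core/graph/graph.h"               // Node

namespace tensorflow {

// Hypothetical defensive variant (not in this commit): read leakyrelu_alpha
// if present, otherwise fall back to LeakyRelu's documented default (0.2)
// instead of CHECK-failing.
static float GetLeakyReluAlphaOrDefault(const Node* orig_node) {
  float leakyrelu_alpha = 0.2f;
  if (!GetNodeAttr(orig_node->def(), "leakyrelu_alpha", &leakyrelu_alpha)
           .ok()) {
    // Attribute not set on this node; keep the default above.
  }
  return leakyrelu_alpha;
}

}  // namespace tensorflow
```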
57 changes: 57 additions & 0 deletions tensorflow/core/graph/mkl_layout_pass_test.cc

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions tensorflow/core/grappler/op_types.cc
@@ -311,6 +311,8 @@ bool IsImmutableConst(const NodeDef& node) {

bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }

bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }

bool IsLess(const NodeDef& node) { return node.op() == "Less"; }

bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
@@ -546,6 +548,8 @@ bool IsSymbolicGradient(const NodeDef& node) {
return node.op() == "SymbolicGradient";
}

bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }

bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }

bool IsTensorArray(const NodeDef& node) {
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.h
@@ -95,6 +95,7 @@ bool IsIgammac(const NodeDef& node);
bool IsImag(const NodeDef& node);
bool IsImmutableConst(const NodeDef& node);
bool IsInvGrad(const NodeDef& node);
bool IsLeakyRelu(const NodeDef& node);
bool IsLess(const NodeDef& node);
bool IsLessEqual(const NodeDef& node);
bool IsLog(const NodeDef& node);
@@ -182,6 +183,7 @@ bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsSymbolicGradient(const NodeDef& node);
bool IsTanh(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
bool IsTensorArray(const NodeDef& node);
bool IsTile(const NodeDef& node);
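These predicates are what let grappler's remapper recognize LeakyRelu and Tanh nodes when matching fusion patterns. A hedged sketch of how such predicates are typically combined follows; the helper name is illustrative, not the actual remapper code.

```cpp
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"

namespace tensorflow {
namespace grappler {

// Illustrative helper: activations that a remapper-style pass could fold into
// a fused contraction op.
static bool IsFusibleActivation(const NodeDef& node) {
  return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node) ||
         IsTanh(node);
}

}  // namespace grappler
}  // namespace tensorflow
```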
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/optimizers/BUILD
@@ -813,6 +813,7 @@ tf_cuda_cc_test(
deps = [
":remapper",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@@ -835,6 +836,7 @@ tf_cc_test_mkl(
deps = [
":remapper",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
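The added //tensorflow/cc:cc_ops_internal dependency is needed because the generated C++ wrapper for LeakyRelu lives in the internal ops target rather than in standard_ops. A small usage sketch, assuming a test-style scope (the op's alpha attribute defaults to 0.2):

```cpp
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

// Build a small graph that routes a placeholder through the internal
// LeakyRelu wrapper, mirroring how the tests below construct activations.
void BuildLeakyReluGraph() {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto input = tensorflow::ops::Placeholder(s.WithOpName("input"),
                                            tensorflow::DT_FLOAT);
  auto leaky =
      tensorflow::ops::internal::LeakyRelu(s.WithOpName("activation"), input);
  tensorflow::ops::Identity(s.WithOpName("fetch"), leaky);
}
```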
106 changes: 66 additions & 40 deletions tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/devices.h"
@@ -34,8 +35,9 @@ class MklRemapperTest : public GrapplerTest {
const string kAddV2Op = "AddV2";

protected:
void FuseConv2DWithBiasAndAddNOrAdd(const string& data_format, bool has_relu,
string add_op, bool add_with_bcast) {
void FuseConv2DWithBiasAndAddNOrAdd(const string& data_format,
const string& activation, string add_op,
bool add_with_bcast) {
using ::tensorflow::ops::Placeholder;

tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -73,27 +75,51 @@ class MklRemapperTest : public GrapplerTest {
if (add_op == kAddNOp) {
auto addn = ops::AddN(s.WithOpName(add_op),
std::initializer_list<Input>{input_addn, bias_add});
if (has_relu) {
auto relu = ops::Relu(s.WithOpName("relu"), addn);
ops::Identity(s.WithOpName("fetch"), relu);
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
if (activation == "Relu") {
ops::Identity(fetch, ops::Relu(activate, addn));
} else if (activation == "Relu6") {
ops::Identity(fetch, ops::Relu6(activate, addn));
} else if (activation == "Elu") {
ops::Identity(fetch, ops::Elu(activate, addn));
} else if (activation == "LeakyRelu") {
ops::Identity(fetch, ops::internal::LeakyRelu(activate, addn));
} else {
ops::Identity(s.WithOpName("fetch"), addn);
DCHECK(activation == "None");
ops::Identity(fetch, addn);
}
} else if (add_op == kAddV2Op) {
auto add = ops::AddV2(s.WithOpName(add_op), input_addn, bias_add);
if (has_relu) {
auto relu = ops::Relu(s.WithOpName("relu"), add);
ops::Identity(s.WithOpName("fetch"), relu);
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
if (activation == "Relu") {
ops::Identity(fetch, ops::Relu(activate, add));
} else if (activation == "Relu6") {
ops::Identity(fetch, ops::Relu6(activate, add));
} else if (activation == "Elu") {
ops::Identity(fetch, ops::Elu(activate, add));
} else if (activation == "LeakyRelu") {
ops::Identity(fetch, ops::internal::LeakyRelu(activate, add));
} else {
ops::Identity(s.WithOpName("fetch"), add);
DCHECK(activation == "None");
ops::Identity(fetch, add);
}
} else {
auto add = ops::Add(s.WithOpName(add_op), input_addn, bias_add);
if (has_relu) {
auto relu = ops::Relu(s.WithOpName("relu"), add);
ops::Identity(s.WithOpName("fetch"), relu);
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
if (activation == "Relu") {
ops::Identity(fetch, ops::Relu(activate, add));
} else if (activation == "Relu6") {
ops::Identity(fetch, ops::Relu6(activate, add));
} else if (activation == "Elu") {
ops::Identity(fetch, ops::Elu(activate, add));
} else if (activation == "LeakyRelu") {
ops::Identity(fetch, ops::internal::LeakyRelu(activate, add));
} else {
ops::Identity(s.WithOpName("fetch"), add);
DCHECK(activation == "None");
ops::Identity(fetch, add);
}
}
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
@@ -129,7 +155,7 @@ class MklRemapperTest : public GrapplerTest {
bool check_fusion = !add_with_bcast;
int found = 0;
for (const NodeDef& node : output.node()) {
auto fetch_node_name = has_relu ? "relu" : add_op;
auto fetch_node_name = activation != "None" ? "activation" : add_op;
if (node.name() == fetch_node_name) {
if (check_fusion) {
EXPECT_EQ("_FusedConv2D", node.op());
@@ -141,19 +167,19 @@
EXPECT_EQ("input_addn", node.input(3));

const auto fused_ops = node.attr().at("fused_ops").list().s();
if (has_relu) {
if (activation != "None") {
EXPECT_EQ(3, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Add", fused_ops[1]);
EXPECT_EQ("Relu", fused_ops[2]);
EXPECT_EQ(activation, fused_ops[2]);
} else {
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Add", fused_ops[1]);
}
} else {
if (has_relu) {
EXPECT_EQ(node.op(), "Relu");
if (activation != "None") {
EXPECT_EQ(node.op(), activation);
ASSERT_EQ(node.input_size(), 1);
EXPECT_EQ(node.input(0), add_op);
} else {
@@ -174,38 +200,38 @@
}
};

#define CREATE_CONV2DFUSION_TEST(data_format, addop, relu, bcast) \
TEST_F( \
MklRemapperTest, \
FuseConv2DWithBiasAnd##addop##_##data_format##_relu##relu##_addbcast##bcast) { \
const bool kShouldFuseRelu = relu; \
const bool kIsAddWithBcast = bcast; \
FuseConv2DWithBiasAndAddNOrAdd(#data_format, relu, #addop, bcast); \
#define CREATE_CONV2DFUSION_TEST(data_format, addop, activation, bcast) \
TEST_F( \
MklRemapperTest, \
FuseConv2DWithBiasAnd##addop##_##data_format##_activation##activation##_addbcast##bcast) { \
FuseConv2DWithBiasAndAddNOrAdd(#data_format, #activation, #addop, bcast); \
}

#define CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(addop) \
CREATE_CONV2DFUSION_TEST(NHWC, addop, false, false); \
CREATE_CONV2DFUSION_TEST(NHWC, addop, true, false); \
CREATE_CONV2DFUSION_TEST(NCHW, addop, false, false); \
CREATE_CONV2DFUSION_TEST(NCHW, addop, true, false);
#define CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(data_format, addop, bcast) \
CREATE_CONV2DFUSION_TEST(data_format, addop, Relu, bcast); \
CREATE_CONV2DFUSION_TEST(data_format, addop, Relu6, bcast); \
CREATE_CONV2DFUSION_TEST(data_format, addop, Elu, bcast); \
CREATE_CONV2DFUSION_TEST(data_format, addop, LeakyRelu, bcast); \
CREATE_CONV2DFUSION_TEST(data_format, addop, None, bcast);

#define CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(addop) \
CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false);

CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(AddN);

#define CREATE_CONV2DFUSION_ADD_BCAST_TEST(addop) \
CREATE_CONV2DFUSION_TEST(NHWC, addop, false, false); \
CREATE_CONV2DFUSION_TEST(NHWC, addop, true, false); \
CREATE_CONV2DFUSION_TEST(NCHW, addop, false, false); \
CREATE_CONV2DFUSION_TEST(NCHW, addop, true, false); \
CREATE_CONV2DFUSION_TEST(NHWC, addop, false, true); \
CREATE_CONV2DFUSION_TEST(NHWC, addop, true, true); \
CREATE_CONV2DFUSION_TEST(NCHW, addop, false, true); \
CREATE_CONV2DFUSION_TEST(NCHW, addop, true, true);
#define CREATE_CONV2DFUSION_ADD_BCAST_TEST(addop) \
CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false); \
CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, true); \
CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, true);

CREATE_CONV2DFUSION_ADD_BCAST_TEST(Add);
CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);

#undef CREATE_CONV2DFUSION_ADD_NOBCAST_TEST
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
#undef CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST
#undef CREATE_CONV2DFUSION_TEST
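For readability, a single instantiation produced by these macros — for example CREATE_CONV2DFUSION_TEST(NHWC, AddV2, LeakyRelu, false) — expands roughly to:

```cpp
TEST_F(MklRemapperTest,
       FuseConv2DWithBiasAndAddV2_NHWC_activationLeakyRelu_addbcastfalse) {
  FuseConv2DWithBiasAndAddNOrAdd("NHWC", "LeakyRelu", "AddV2", false);
}
```

Each add-op / data-format / activation / broadcast combination therefore gets its own named test case, now including LeakyRelu and the no-activation baseline.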

#define REGISTER_TEST(NAME, T, INPUT) \
(The remaining changed files are not rendered here.)
