oneDNN NHWC fixes #40049

Merged
3 changes: 2 additions & 1 deletion paddle/fluid/framework/executor.cc
@@ -174,10 +174,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc, bool keep_kid_scopes) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
platform::RegisterModelLayout(ctx->ops_, place_);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
keep_kid_scopes);
}
@@ -117,7 +117,7 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();

AddOpCompat(OpCompat("elementwise_add"))
1 change: 1 addition & 0 deletions paddle/fluid/framework/naive_executor.cc
@@ -41,6 +41,7 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
void NaiveExecutor::Run() {
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
platform::RegisterModelLayout(ops_, place_);
#endif
platform::ScopedFlushDenormal flush;
for (auto &op : ops_) {
2 changes: 1 addition & 1 deletion paddle/fluid/operators/lrn_op.cc
@@ -221,7 +221,7 @@ class LRNOp : public framework::OperatorWithKernel {
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// Some models may have intentionally set "AnyLayout" for lrn
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
9 changes: 2 additions & 7 deletions paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
@@ -50,13 +50,8 @@ class PReluMKLDNNHandler
if (weights->dims().size() != x->dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
if (mode == "channel") {
if (data_format == "NHWC") {
new_weights_dims[x->dims().size() - 1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
} else {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
}
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
}
weights_dims = std::move(new_weights_dims);
}
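A minimal illustrative sketch of the simplification above (the helper name is hypothetical, not part of the PR): with the oneDNN kernel the logical dims are always NCHW-ordered, so in "channel" mode the per-channel alpha is broadcast along dim 1 regardless of the model's data_format, which is why the NHWC branch could be dropped.

#include <cstdint>
#include <vector>

// Hypothetical helper: build broadcastable alpha dims for mode == "channel".
// For a 4-D input this yields {1, C, 1, 1}; the channel count always lands
// on index 1 because the oneDNN kernel works on NCHW-ordered logical dims.
std::vector<int64_t> ChannelAlphaDims(int64_t channels, size_t input_rank) {
  std::vector<int64_t> dims(input_rank, 1);  // assumes input_rank >= 2
  dims[1] = channels;
  return dims;
}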
2 changes: 1 addition & 1 deletion paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc
@@ -94,7 +94,7 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) {

TEST(test_pool2d_relu_relu_nhwc, cpu_place) {
framework::DDim dims({1, 4, 8, 512}); // NHWC shape
framework::DDim expected_dims({1, 512, 3, 7}); // NHWC expected shape
framework::DDim expected_dims({1, 512, 3, 7}); // NCHW expected shape
platform::CPUPlace p;
framework::Scope scope;

34 changes: 33 additions & 1 deletion paddle/fluid/operators/prelu_op.cc
@@ -17,6 +17,26 @@ limitations under the License. */
namespace paddle {
namespace operators {

framework::OpKernelType innerGetKernelTypeForVar(
const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) {
#ifdef PADDLE_WITH_MKLDNN
auto isOneDNNKernelChosen =
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN);
auto isNotOneDNNTensor = (tensor.layout() != framework::DataLayout::kMKLDNN);
auto isModelNHWC =
(paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC);
// All inputs (including alpha) need shape rotating
if (isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}

class PReluOp : public framework::OperatorWithKernel {
public:
PReluOp(const std::string &type, const framework::VariableNameMap &inputs,
@@ -53,7 +73,7 @@ class PReluOp : public framework::OperatorWithKernel {
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format_str));
if (data_format_str == "NCHW") {
if (data_format_str == "NCHW" || ctx->IsRunMKLDNNKernel()) {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
platform::errors::InvalidArgument(
@@ -128,6 +148,12 @@ class PReluOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
return innerGetKernelTypeForVar(tensor, expected_kernel_type);
}
};

class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -212,6 +238,12 @@ class PReluGradOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
return innerGetKernelTypeForVar(tensor, expected_kernel_type);
}
};

template <typename T>
28 changes: 28 additions & 0 deletions paddle/fluid/platform/mkldnn_helper.h
@@ -559,6 +559,34 @@ inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
}
}

inline void RegisterModelLayout(
std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
const platform::Place& place) {
if (platform::is_cpu_place(place)) {
auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
const std::string& attrib_name) -> bool {
if (op->HasAttr(attrib_name)) {
auto data_format = op->Attr<std::string>(attrib_name);
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
data_format.compare("NHWC") == 0 ? framework::DataLayout::kNHWC
: framework::DataLayout::kNCHW);
return true;
} else {
return false;
}
};

for (auto& op : ops) {
if (check_attrib(op, std::string("data_format"))) {
return;
}
if (check_attrib(op, std::string("data_layout"))) {
return;
}
}
}
}

inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
op->GetAttrIfExists<bool>("use_quantizer"));
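A minimal usage sketch of the new helper (assumed context: it only compiles inside the Paddle codebase, and the wrapper function name is hypothetical): an executor registers the model layout once per run, and kernels such as PRelu's GetKernelTypeForVar later read it back from the device-context TLS.

#ifdef PADDLE_WITH_MKLDNN
// Hypothetical wrapper for illustration; the individual calls appear in the
// diffs above.
void SetupModelLayoutForRun(
    std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
    const platform::Place& place) {
  // Scan the ops for a "data_format"/"data_layout" attribute and cache
  // NHWC vs. NCHW in thread-local storage.
  platform::RegisterModelLayout(ops, place);

  // Later, a kernel can ask whether the whole model was declared as NHWC.
  bool model_is_nhwc =
      platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
      framework::DataLayout::kNHWC;
  (void)model_is_nhwc;  // real code uses this to decide layout transforms
}
#endif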
@@ -25,17 +25,120 @@
import hypothesis.strategies as st


class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
# the two inputs of elementwise_add are tensor
class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
Member:
Similar condition is in test_mkldnn_conv_mish_fuse_pass.py and test_mkldnn_depthwise_conv.py. Should we enable testing NHWC there too?

Contributor Author:
@Silv3S This UT was added in #39654, so I just copied it here as suggested by @lidanqing-intel. I do not intend to extend it in this PR.

# needs to be deleted!!!
if attrs[1]['data_format'] == "NHWC":
if attrs[1]['data_format'] == "NHWC" and attrs[3]['axis'] == 0:
return False
if attrs[1]['data_format'] == "NCHW" and attrs[3]['axis'] == -1:
return False
return True

def sample_program_config(self, draw):
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.sampled_from([1, 2, 4]))
paddings = draw(st.sampled_from([[0, 3], [1, 1], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
axis = draw(st.sampled_from([-1, 0]))
batch_size = draw(st.integers(min_value=1, max_value=4))

def generate_input():
if data_format == "NCHW":
return np.random.random(
[batch_size, 48, 64, 64]).astype(np.float32)
else:
return np.random.random(
[batch_size, 64, 64, 48]).astype(np.float32)

def generate_weight():
return np.random.random(
[48, int(48 / groups), 3, 3]).astype(np.float32)

relu_op = OpConfig(
type="relu",
inputs={"X": ["input_data"]},
outputs={"Out": ["relu_out"]},
attrs={})

conv2d_op1 = OpConfig(
type="conv2d",
inputs={"Input": ["relu_out"],
"Filter": ["conv_weight1"]},
outputs={"Output": ["conv_output1"]},
attrs={
"data_format": data_format,
"dilations": dilations,
"padding_algorithm": padding_algorithm,
"groups": groups,
"paddings": paddings,
"strides": strides
})

conv2d_op2 = OpConfig(
type="conv2d",
inputs={"Input": ["input_data"],
"Filter": ["conv_weight2"]},
outputs={"Output": ["conv_output2"]},
attrs={
"data_format": data_format,
"dilations": dilations,
"padding_algorithm": padding_algorithm,
"groups": groups,
"paddings": paddings,
"strides": strides
})

elt_op = OpConfig(
type="elementwise_add",
inputs={"X": ["conv_output1"],
"Y": ["conv_output2"]},
outputs={"Out": ["elementwise_output"]},
attrs={'axis': axis})

model_net = [relu_op, conv2d_op1, conv2d_op2, elt_op]

program_config = ProgramConfig(
ops=model_net,
weights={
"conv_weight1": TensorConfig(data_gen=partial(generate_weight)),
"conv_weight2": TensorConfig(data_gen=partial(generate_weight))
},
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input))
},
outputs=["elementwise_output"])

return program_config

def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["relu", "conv2d", "conv2d"], (1e-5, 1e-5)

def test(self):
self.run_and_statis(
quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])


'''
Member:
Why is the whole original unit test commented out?

Contributor Author:
I just copied this UT from #39654 as requested, so I am not sure of the author's intention.

class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
if "elementwise_weight" in program_config.weights:
if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[1]:
if attrs[2]['axis'] != 1:
return False
if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[3]:
if attrs[2]['axis'] != -1:
return False
return True

def sample_program_config(self, draw):
@@ -101,7 +204,7 @@ def generate_weight2():
"strides": strides
})

if axis == -1 or axis == 0:
if axis == 0:
elt_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data1"],
@@ -118,14 +221,12 @@ def generate_weight2():

model_net = [relu_op, conv2d_op, elt_op]

if axis == 1:
if axis == 0:
program_config = ProgramConfig(
ops=model_net,
weights={
"conv_weight":
TensorConfig(data_gen=partial(generate_weight1)),
"elementwise_weight":
TensorConfig(data_gen=partial(generate_weight2))
TensorConfig(data_gen=partial(generate_weight1))
},
inputs={
"input_data1":
@@ -137,7 +238,9 @@ def generate_weight2():
ops=model_net,
weights={
"conv_weight":
TensorConfig(data_gen=partial(generate_weight1))
TensorConfig(data_gen=partial(generate_weight1)),
"elementwise_weight":
TensorConfig(data_gen=partial(generate_weight2))
},
inputs={
"input_data1":
@@ -154,7 +257,7 @@ def sample_predictor_configs(self, program_config):
def test(self):
self.run_and_statis(
quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])

'''

if __name__ == "__main__":
unittest.main()