From 469106115c49682b25038a666fd71bd4a10fb66b Mon Sep 17 00:00:00 2001 From: jakpiase Date: Tue, 5 Jul 2022 17:14:26 +0200 Subject: [PATCH] changes to GetKernelTypeForVar --- paddle/fluid/operators/pad2d_op.cc | 14 ++++++-------- paddle/fluid/operators/pad3d_op.cc | 14 ++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc index e7f0c6507bf70b..de45a2ff811cde 100644 --- a/paddle/fluid/operators/pad2d_op.cc +++ b/paddle/fluid/operators/pad2d_op.cc @@ -722,14 +722,12 @@ class Pad2dOp : public framework::OperatorWithKernel { const framework::OpKernelType& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) && - (tensor.layout() != framework::DataLayout::kMKLDNN)) { - auto attrs = Attrs(); - auto ar = paddle::framework::AttrReader(attrs); - const std::string data_format = ar.Get("data_format"); - return framework::OpKernelType( - expected_kernel_type.data_type_, - tensor.place(), - framework::StringToDataLayout(data_format)); + (tensor.layout() != framework::DataLayout::kMKLDNN) && + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), + framework::DataLayout::kNHWC); } #endif return framework::OpKernelType( diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc index e4b32b3d7a76ea..7d4f4826cae888 100644 --- a/paddle/fluid/operators/pad3d_op.cc +++ b/paddle/fluid/operators/pad3d_op.cc @@ -57,14 +57,12 @@ class Pad3dOp : public framework::OperatorWithKernel { const framework::OpKernelType& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) && - (tensor.layout() != framework::DataLayout::kMKLDNN)) { - auto attrs = Attrs(); - auto ar = paddle::framework::AttrReader(attrs); - const std::string data_format = ar.Get("data_format"); - return framework::OpKernelType( - expected_kernel_type.data_type_, - tensor.place(), - framework::StringToDataLayout(data_format)); + (tensor.layout() != framework::DataLayout::kMKLDNN) && + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), + framework::DataLayout::kNHWC); } #endif return framework::OpKernelType(