oneDNN NHWC fixes #40049
Conversation
Thanks for your contribution!
Branch updated: 1763e97 → 70869b8, then 1db6f35 → e576182, then e576182 → 72fd182.
- fix
- compilation fixes
- fix
- fix
- fix
- fix
- compilation fix
- comment fix
- lint
- update mkldnn conv_elementwise_add_fuse_pass UT
- NHWC changes to prelu
- alpha dims
- UT fix
- fix to UT
- lint
- some fixes
- added NHWC support to BWD of prelu
- reverted removal of resetting cu_layout in clearing of caching
@baoachun Hi, this PR passed all CIs. Please review.
auto check_attrib = [&](std::unique_ptr<framework::OperatorBase>& op,
                        const std::string& attrib_name) -> bool {
Suggested change (the lambda uses only its parameters, so the `[&]` capture can be dropped):

auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
                       const std::string& attrib_name) -> bool {
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py
paddle/fluid/operators/prelu_op.cc (outdated)
framework::OpKernelType GetKernelTypeForVar(
    const std::string &var_name, const Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // All inputs (including alpha) need shape rotating
  if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
      (tensor.layout() != framework::DataLayout::kMKLDNN) &&
      paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(),
                                   framework::DataLayout::kNHWC);
  }
#endif
  return framework::OpKernelType(expected_kernel_type.data_type_,
                                 tensor.place(), tensor.layout());
}
Don't repeat yourself, please.
ok
LGTM
def is_program_valid(self, program_config: ProgramConfig) -> bool:
    attrs = [
        program_config.ops[i].attrs
        for i in range(len(program_config.ops))
    ]
    # If the problem has been fixed, the judgment
A similar condition is in test_mkldnn_conv_mish_fuse_pass.py and test_mkldnn_depthwise_conv.py. Should we enable NHWC testing there too?
@Silv3S This UT was added in #39654, so I just copied it here as suggested by @lidanqing-intel. I do not intend to extend it in this PR.
    quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])
'''
Why is the whole original unit test commented out?
I just copied this UT from #39654 as requested, so I am not sure of the author's intention.
LGTM
LGTM
@baoachun Please review.
LGTM
PR types
Bug fixes
PR changes
Others
Describe
The oneDNN NHWC implementation needs to know upfront whether the model is NHWC or NCHW, so a mechanism was added to check that.
The relevant UT from #39654 was also added.