MKLDNN-DO Subgraph Optimization #966
Conversation
/azp run
Azure Pipelines successfully started running 10 pipeline(s).
- Mkl and MKL are used inconsistently.
- Put the MklKernel code in a .cc file (not a header).
- Prefer using mkldnn instead of mkl everywhere to avoid confusion with the MKL library.
ORT_UNUSED_PARAMETER(attributes_prefix);
}

virtual Status CreatePrimitives(const ONNXRunTimeTensor* input_tensors,
input_tensor
I think input_tensors is correct; we access the shape, dims, and data by index. Can you please confirm?
mkldnn::memory::dims src_dims(x_shape.GetDims().begin(), x_shape.GetDims().end());
std::string key;
key.reserve(128);
key = subgraph_key_;
Can AddDimsToKey be written so that it accepts a const string and returns the newly formed key? Moreover, we should rename it to GenerateKey(), since addition is just one way to generate the key.
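For illustration, a minimal sketch of the suggested interface (the GenerateKey name and the dims parameter are hypothetical here; the existing AddDimsToKey helper appends to the key in place):

```cpp
#include <sstream>
#include <string>

#include "mkldnn.hpp"

// Hypothetical sketch: take the base key by const reference and return the
// newly formed key instead of mutating the caller's string.
static std::string GenerateKey(const std::string& base_key,
                               const mkldnn::memory::dims& dims) {
  std::ostringstream os;
  os << base_key;
  for (const auto& d : dims) {
    os << d << "_";
  }
  return os.str();
}
```

The call site above would then become something like `key = GenerateKey(subgraph_key_, src_dims);`.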
@pranavsharma AddDimsToKey is widely used in the vanilla mkldnn operators. Should I go ahead and modify it?
Can you add some high-level documentation that describes the classes, abstractions, and their relationships?
}
virtual ~MklKernel(){};

virtual void ReadAttributes(const std::unordered_map<std::string,
Is it possible to leverage OpNodeProtoHelper for reading attributes? Then you wouldn't need to define all the Get*Attr() methods below.
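For reference, a minimal sketch of what reading attributes through OpNodeProtoHelper could look like, assuming the ProtoHelperNodeContext wrapper; the attribute names below are only illustrative:

```cpp
#include <vector>

#include "core/framework/op_node_proto_helper.h"
#include "core/graph/graph.h"

using namespace onnxruntime;

// Sketch: wrap the node and query typed attributes directly instead of
// hand-written Get*Attr() helpers. Attribute names are illustrative.
void ReadConvAttributes(const Node& node) {
  ProtoHelperNodeContext ctx(node);
  OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);

  std::vector<int64_t> strides;
  if (attrs.GetAttrs<int64_t>("strides", strides).IsOK()) {
    // use strides ...
  }

  // Scalar attribute with a default value.
  int64_t group = attrs.GetAttrOrDefault<int64_t>("group", 1);
  (void)group;
}
```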
I looked into it and it's not straightforward. Can this be done after this PR?
namespace onnxruntime {

struct MklNode {
There are already onnxruntime representations of Node, subgraphs, etc. Were those too heavy to use? It would have been nice to leverage more of the onnxruntime classes where possible, rather than adding new ones to bridge onnxruntime and the mkldnn library.
MklDnnNode is totally different and I am using it for MklDnn IR of a subgraph.
I tried this PR on a production model and got a crash. Here is the call stack:
#0 0x000000000047e487 in std::__find_if<__gnu_cxx::__normal_iterator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >, __gnu_cxx::__ops::_Iter_equals_val<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const> > (__first="Conv",
while (node_index < graph_viewer.MaxNodeIndex()) {
  auto node = graph_viewer.GetNode(node_index);
  std::vector<std::string>::iterator it = std::find(
      mkl_ops.begin(), mkl_ops.end(), node->OpType());
The crash is because node is nullptr. We need to figure out why node is null.
@yufenglee, is the model you tried this PR on available in the ONNX Model Zoo?
@sreekanth-yalachigere, no, it is a production model and I can't share it with you for confidentiality reasons.
GetNode() may return nullptr if the node has been freed, which can happen if some transforms have run.
I don't think you can iterate through nodes by index from 0 to max.
Shouldn't you use GetNodesInTopologicalOrder()?
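For illustration, a sketch of the suggested iteration (a minimal example assuming GraphViewer::GetNodesInTopologicalOrder(); mkl_ops here stands for the op-type list from the snippet above):

```cpp
#include <algorithm>
#include <string>
#include <vector>

#include "core/graph/graph_viewer.h"

using namespace onnxruntime;

// Sketch: walk the graph in topological order and skip any index whose
// node has been removed (GetNode may return nullptr after transforms).
void CollectMklNodes(const GraphViewer& graph_viewer,
                     const std::vector<std::string>& mkl_ops) {
  for (NodeIndex node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    const Node* node = graph_viewer.GetNode(node_index);
    if (node == nullptr) continue;  // defensive: node was freed/removed
    auto it = std::find(mkl_ops.begin(), mkl_ops.end(), node->OpType());
    if (it != mkl_ops.end()) {
      // node is an mkldnn-capable op; add it to the current subgraph ...
    }
  }
}
```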
Creating a new PR because many files were modified after the memcpy code merge.
This optimization creates a subgraph of mkldnn operators and exploits the full potential of mkldnn through blocked-format propagation, which improves performance significantly.
For example, Resnet50 currently requires 109 data reorders.
With this optimization, we need only one data reorder and we propagate blocked mkldnn memory all the way to the end of the subgraph:
onnxruntime_data(nchw) -> conv(nchw8c) -> batchnorm-relu(nchw8c) -> pool(nchw8c) -> conv(nchw8c) ... reorder(nchw)
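As a rough illustration of that single boundary reorder (a minimal sketch against the mkldnn 0.x C++ API, not the actual code in this PR), the blocked output is converted back to plain nchw only once, at the end of the subgraph:

```cpp
#include <vector>

#include "mkldnn.hpp"

// Sketch (mkldnn 0.x API, illustrative only): intermediate results stay in
// a blocked format such as nChw8c; one reorder at the subgraph boundary
// converts the final output back to plain nchw for onnxruntime.
void ReorderToNchw(const mkldnn::memory& blocked_src,
                   const mkldnn::memory::dims& dims,
                   void* dst_buffer,
                   const mkldnn::engine& cpu_engine) {
  mkldnn::memory::desc dst_md(dims, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format::nchw);
  mkldnn::memory dst({dst_md, cpu_engine}, dst_buffer);

  // The only data reorder in the whole subgraph.
  std::vector<mkldnn::primitive> net;
  net.push_back(mkldnn::reorder(blocked_src, dst));
  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
}
```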
Subgraph optimization can be enabled by setting the following environment variable:
ORT_MKLDNN_SUBGRAPH=1