LSTM with MKL-DNN produces wrong output after weights are changed #16037
Comments
Hey, this is the MXNet Label Bot.
@mxnet-label-bot add [Bug, MKLDNN, RNN]
Probably it's because the stateful RNN op doesn't check whether the weights have changed. We will look at this. @pengzhao-intel
@zixuanweeei Would you please have a look at this?
@ZhennanQin Sure. Just as you said, it is definitely caused by the stateful RNN op not checking the weights again after it has been initialized with the MKL-DNN memory format during inference.
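The diagnosis above can be illustrated with a toy sketch in plain Python. This is not MXNet or MKL-DNN code: `CachedLinearOp`, its `check_weights` flag, and the cached copy are hypothetical stand-ins for a stateful op that converts its weights to an internal layout on the first inference call and then keeps reusing that copy.

```python
class CachedLinearOp:
    """Toy stateful op computing a dot product of weights and input.

    Hypothetical illustration only; it mimics an op that caches a
    converted copy of its weights for inference.
    """

    def __init__(self, check_weights):
        self.check_weights = check_weights  # assumed switch: validate the cache or not
        self._cached = None   # stand-in for weights in a backend-specific layout
        self._source = None   # snapshot of the weights the cache was built from

    @staticmethod
    def _dot(weights, x):
        return sum(w * xi for w, xi in zip(weights, x))

    def forward(self, weights, x, is_train=False):
        if is_train:
            # Training path: always reads the current weights.
            return self._dot(weights, x)
        stale = self.check_weights and self._source != list(weights)
        if self._cached is None or stale:
            # Build (or rebuild) the cached copy from the current weights.
            self._cached = tuple(weights)
            self._source = list(weights)
        return self._dot(self._cached, x)


x = [1.0, 2.0]
buggy = CachedLinearOp(check_weights=False)
fixed = CachedLinearOp(check_weights=True)

w = [1.0, 1.0]
first_buggy = buggy.forward(w, x)
first_fixed = fixed.forward(w, x)

w = [2.0, 2.0]  # simulate a weight update
second_buggy = buggy.forward(w, x)                 # reuses the stale cache
second_fixed = fixed.forward(w, x)                 # cache is rebuilt
train_buggy = buggy.forward(w, x, is_train=True)   # training path sees new weights

print(first_buggy, second_buggy, second_fixed, train_buggy)
```

In this sketch the op without the check returns the same inference output before and after the update, while the training path and the op with the check both reflect the new weights, which matches the symptom reported in this issue.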
@matteosal Thanks for reporting this issue. We are addressing the problem, and a PR is on the way.
Great!
@matteosal That's right. The problem won't appear with GRU, because we haven't integrated MKL-DNN GRU into MXNet yet. It will be available in the near future.
@matteosal Thanks for reporting the issues; the reports are really helpful.
Sure, I'm writing to you from my Wolfram work email.
Fixed and closing. Thanks for reporting the issue :)
Description
mode='lstm' and bind it
The output doesn't change unless the second forward pass is performed in training mode (is_train=True). Setting MXNET_MKLDNN_ENABLED=0 doesn't fix the issue, but using a build without MKL-DNN does.
This severely impacts training with a validation set, because performance on the validation set is typically evaluated with is_train=False after several weight updates. In this case, validation shows no improvement because the output of the layer is stuck at its value from the very first training iteration.
Environment info (Required)
Package used (Python/R/Scala/Julia): python
Build info (Required if built from source)
Compiler (gcc/clang/mingw/visual studio): gcc
MXNet commit hash: 076b2f3
Build config: plain config.mk, except for USE_OPENCV=0
Minimum reproducible example
When using a build with MKL-DNN, this script prints something like this:
This shows that the output doesn't change after changing the weights unless the forward pass is performed in training mode. Setting MXNET_MKLDNN_ENABLED=0 doesn't fix the issue, but using a build without MKL-DNN does.
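A regression check for this class of bug can be sketched without any framework. The helper below is hypothetical and not part of MXNet: it assumes `forward` is any callable taking `(weights, is_train)` and returning a comparable output, and the two toy forwards exist only to exercise it.

```python
def inference_tracks_weights(forward, weights, new_weights):
    """Return True if the inference output changes when the weights change."""
    before = forward(weights, False)      # first inference pass
    after = forward(new_weights, False)   # inference pass after the "update"
    return before != after


def make_stale_forward():
    """Toy forward that freezes the weights seen on its first inference call."""
    cache = {}

    def forward(weights, is_train):
        if is_train:
            return sum(weights)
        cache.setdefault("w", list(weights))  # only stored on the first call
        return sum(cache["w"])

    return forward


def fresh_forward(weights, is_train):
    """Toy forward that always reads the current weights."""
    return sum(weights)


stale_ok = inference_tracks_weights(make_stale_forward(), [1.0, 2.0], [3.0, 4.0])
fresh_ok = inference_tracks_weights(fresh_forward, [1.0, 2.0], [3.0, 4.0])
print(stale_ok, fresh_ok)
```

With a real layer, `forward` would wrap the bound executor's inference call; the check only makes sense when the two weight sets are expected to produce different outputs for the chosen input.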