ERROR - tf2onnx.tfonnx: Failed to convert node 'StatefulPartitionedCall/functional_1/layer_normalization/FusedBatchNormV3' #1175
Looking at the source code in TensorFlow, the layer does the following on lines 1217-1265.
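The original snippet was not preserved in this thread; the following is a paraphrased, self-contained sketch of that fused path (the function name and the collapsed pre_dim/in_dim shapes are my abbreviation of the actual source, so treat it as illustrative):

```python
import tensorflow as tf

def fused_layer_norm_sketch(inputs, pre_dim, in_dim, epsilon=1e-3):
    # Collapse the batch dims and the normalized dims into the
    # [1, pre_dim, in_dim, 1] shape that fused_batch_norm expects.
    x = tf.reshape(inputs, [1, pre_dim, in_dim, 1])

    # gamma/beta have the wrong shape for fused_batch_norm, so placeholder
    # scale/offset tensors of the right shape are created here; the learned
    # gamma/beta are applied in a separate calculation afterwards.
    scale = tf.ones([pre_dim], dtype=x.dtype)
    offset = tf.zeros([pre_dim], dtype=x.dtype)

    outputs, _, _ = tf.nn.fused_batch_norm(
        x, scale=scale, offset=offset,
        epsilon=epsilon, data_format="NCHW")
    return tf.reshape(outputs, tf.shape(inputs))
```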
The comment suggests that two placeholder tensors are created with the correct shape so that nn.fused_batch_norm can be called. Note that the conversion works when the non-fused batch normalisation is used inside the LayerNormalization layer. So is this a feature that is simply not supported in the current version of tf2onnx?
tf2onnx only supports converting models for inference, not training. I think the above TF behavior is only needed during training, since the scale/offset are constant for inference models.
Hi @TomWildenhain-Microsoft, thanks for the reply. If that is indeed the case, is it possible for tf2onnx to skip these layers? I only want the ONNX model for inference, but the saved model obviously still has these layers from training.
This is my current understanding, but @guschmue correct me if any of this is wrong: BatchNorm has an "is_training" property which can be true or false. When true, the mean and variance values are computed dynamically during training. When false, they are frozen and stored in the op. If is_training is true when we try to convert the model, the mean and variance values are empty, so we don't have enough information to run inference. To convert a Keras model with batch normalization, you must set the layer to trainable=False before saving the model (see the sketch below). I've been trying to do this with the model you provided, but keep getting trainable=True despite using set_learning_phase(0) before saving; it might be related to an existing issue. @guschmue do you know how to set is_training to false on a batch norm layer in Keras?
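For reference, a minimal sketch of the freezing step being attempted here (the model paths are hypothetical placeholders):

```python
import tensorflow as tf

# Hypothetical paths; substitute your own saved model.
model = tf.keras.models.load_model("saved_model_dir")

# Freeze every normalization layer before re-saving, so the exported
# graph should use the stored moments instead of training-mode behavior.
for layer in model.layers:
    if isinstance(layer, (tf.keras.layers.BatchNormalization,
                          tf.keras.layers.LayerNormalization)):
        layer.trainable = False

model.save("saved_model_inference")
```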
I also observe a similar problem: when I set the LayerNormalization layers to trainable=False at training time (i.e. I just use fixed normalisation parameters for these layers), the layers appear to have training=True when I load the saved model back up again, as all parameters are learnable again. I don't know if this is intended behaviour or not. It is worth noting, however, that this problem doesn't occur at all when the normal (non-fused) batch normalization is used inside the LayerNormalization layer. This can be achieved by setting epsilon < 1e-5, according to the comments in the TensorFlow source code (normalization.py, lines 1120-1125):
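A minimal sketch of that workaround (the exact epsilon value is illustrative; anything below fused_batch_norm's minimum of roughly 1.001e-5 should force the non-fused path):

```python
import tensorflow as tf

# fused_batch_norm has a minimum supported epsilon (~1.001e-5); choosing
# a smaller value makes LayerNormalization fall back to the non-fused
# batch_normalization() path.
layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
```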
The layer then calls batch_normalization() instead of fused_batch_norm(), and this is converted successfully with tf2onnx.convert.
@HarryAA thanks for looking into this. I've recently improved our messaging for when training is set to true (tf2onnx will display a warning but continue conversion). Ideally models should have training set to false when we convert them, but it seems like there's a bug in TF/keras that isn't setting that correctly on this op. Can you try converting with the latest tf2onnx from master and see if it fixes your issue?
Thanks for the update @TomWildenhain-Microsoft, and sorry for the delay. I have tried the latest version of tf2onnx as you suggested, and it does indeed convert the model successfully. However, when I load the ONNX model into a runtime session I get the following error:
This suggests to me that the fix may just avoid the underlying issue instead of fixing it?
@HarryAA Yes, I think you are right about that. My current conclusion is that there is a bug in Keras/TF that makes it not properly set the training value to false. If the training value is true, tf2onnx raises an error. The best solution would be to work around the Keras bug to get training to be false. If training is true, I'm not sure there is much we can do to convert the model. We can't just skip the layer, since I think the values produced would be incorrect. Does that seem right?
I also encountered the same problem, with the SequenceToSequence model from OpenNMT-tf.
@TomWildenhain-Microsoft Yes, I think it does. The layers definitely need to be included in the conversion because the learnt parameters are used at inference. I have a workaround that avoids FusedBatchNorm, and this works for me, but it certainly isn't a general solution. Even if the bug in TensorFlow is fixed and we are able to set training=False, though, I am not sure this problem will actually go away.
I'm pretty sure it will. Setting training to false also means TF will include the additional data in the op that we need (mean and variance). When training = true, these values are left blank. https://www.tensorflow.org/api_docs/python/tf/raw_ops/FusedBatchNormV3
Okay, that makes sense. So this is an issue that should be raised in the TF repo. Has this been done? Until then, this issue can't be fixed.
I believe I have seen this issue raised before, but I can't find it... it would be worth raising again. My only other question is how TF is able to run inference at all without these values. If it computes the mean/variance from the current sample, we may be able to do the same in this case. If it is using some sort of rolling average based on previous inferences, then there really is nothing we can do, since ONNX inference is stateless. As a workaround, I think I saw someone say you can "hack" the training value to false by accessing some private properties. Not sure where I saw that... In any case, any further research you do in this area would be appreciated.
Okay sure, I will take a closer look at what TF is doing here when performing inference. I'll try to avoid hacky workarounds for the time being and see if I can fix the underlying issue. Cheers.
I took a deep look at this issue today, and here's what I found: the solution is to insert ops that compute the mean/variance when a FusedBatchNormV3 is found to be in training mode and missing those values. I've done this in #1249, which will hopefully fix your issue. On the example model you provided, it does produce the correct answer.
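Conceptually, the inserted ops do something like the following sketch (this illustrates the idea rather than being the actual tf2onnx code; the function name and the NHWC axis choice are my assumptions):

```python
import tensorflow as tf

def batch_norm_with_computed_moments(x, scale, offset, epsilon=1e-3):
    # When FusedBatchNormV3 is in training mode, its mean/variance inputs
    # are empty, so compute the moments from the current sample instead
    # (reducing over N, H, W for an NHWC tensor).
    mean, variance = tf.nn.moments(x, axes=[0, 1, 2])
    return tf.nn.batch_normalization(x, mean, variance, offset, scale, epsilon)
```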
Thanks very much for the fix, @TomWildenhain-Microsoft!
Hi @TomWildenhain-Microsoft, I was using tf2onnx 1.8.1 to generate the ONNX model for the AutoML EfficientDet saved model. I'm still getting a lot of warnings regarding FusedBatchNormV3.
@romil611 have you been able to test whether the ONNX model produces the correct results? If so, it's fine to ignore the warnings. Otherwise, please upload a zipped copy of the saved model and the ONNX file you are getting. You might be able to set training to false to fix it, but I've had difficulty with that in the past.
@TomWildenhain-Microsoft I was able to get rid of that warning when I used tf2onnx 1.10. Anyway, thanks for responding!
Describe the bug
When trying to convert a TensorFlow model containing a tf.keras.layers.LayerNormalization layer, the conversion to ONNX fails. The failure occurs at the FusedBatchNormV3 node when attempting to resize the mean input to have the same shape as the scale input, producing a ValueError: negative dimensions are not allowed.
Urgency
None.
System information
To Reproduce
Describe steps/code to reproduce the behavior:
Create a model containing a tf.keras.layers.LayerNormalization layer, such as:
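The original snippet was not included here; a minimal stand-in that exercises the same layer might look like this (layer sizes and shapes are illustrative):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(128,))
x = tf.keras.layers.Dense(64)(inputs)
# With the default epsilon, LayerNormalization can take the fused path,
# which lowers to FusedBatchNormV3 in the saved graph.
x = tf.keras.layers.LayerNormalization()(x)
model = tf.keras.Model(inputs, x)
model.save("layernorm_model")
```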
Then save the model and attempt to convert it using tf2onnx.convert; the conversion will fail.
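For example, with the tf2onnx command-line converter (paths follow the sketch above):

```
python -m tf2onnx.convert --saved-model layernorm_model --output model.onnx
```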
Screenshots
Additional context
I have included a screenshot of TensorBoard showing the node in question.