[MXNET-620]Fix flaky test batchnorm training #11544
Conversation
@haojin2 @eric-haibin-lin could you help take a look? Thanks
"Enabling this test with a larger atol is better than just disabling it." - Sure. The goal now is to do it properly: dive deep and track down the root causes and possible implications.
@marcoabreu if you take a closer look at "finding 3", we do deliver the same results as the cuDNN version, so the root cause is just that the numeric gradient has a larger error margin compared to the symbolic backward one.
…chnorm test atol to 1e-2
@zheng-da Hi, I have modified your unit test for MKLDNN batchnorm and increased atol to 1e-2, as it's flaky at 1e-3 and 1e-4. The reasons are listed in the PR description: every batchnorm implementation's (MKLDNN, cuDNN) gradient is slightly off from the numeric gradient. Using 1e-2 is more stable and passed 10000 runs.
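For context, the pass/fail decision behind this discussion is the usual NumPy-style elementwise tolerance test, |symbolic - numeric| <= atol + rtol * |numeric|. Below is a minimal plain-Python sketch (no MXNet) of why loosening atol to 1e-2 stabilizes the comparison; the gradient values are invented for illustration, and only the rtol=0.2/atol settings mirror the call in the diff:

```python
# Sketch of the NumPy-style tolerance test that gradient checks rely on.
# The gradients below are hypothetical; a ~5e-3 elementwise gap is the kind
# of discrepancy reported between symbolic and numeric batchnorm gradients.

def close(a, b, rtol, atol):
    """Elementwise |x - y| <= atol + rtol * |y|, as in numpy.allclose."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

symbolic = [0.1200, -0.0310, 0.0045]   # hypothetical backward() gradient
numeric  = [0.1248, -0.0262, 0.0093]   # hypothetical finite-difference estimate

# atol=1e-2 absorbs the ~5e-3 gap even for near-zero gradient entries,
# where the rtol term alone is too small to help.
print(close(symbolic, numeric, rtol=0.2, atol=1e-2))  # True
print(close(symbolic, numeric, rtol=0.2, atol=1e-4))  # False: fails on the small entry
```

The key observation is the third element: for gradient values near zero, rtol contributes almost nothing, so atol alone must cover the numeric-estimation error.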
@roywei thanks for fixing the flaky test.
@@ -1527,7 +1526,7 @@ def check_batchnorm_training(stype):
test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)

stypes = ['row_sparse', 'default']
I'm just curious. does row_sparse make the test flaky?
Without row_sparse, the flakiness comes from the difference between the gradient values of the mkldnn/cudnn implementations of batchnorm and the numeric calculation from check_numeric_gradient.
row_sparse makes it more flaky, and it will fail at the part that tests a varying channel axis, which is not in your test_mkldnn:test_batchnorm.
But isn't that an actual issue? It means that the two implementations are not equivalent and yield different results, right?
I don't really get the statement about row_sparse being removed because of the mkldnn test. Considering we run without mkldnn, how is that related? I might be missing something here, but I get the feeling that row_sparse is being removed because it's causing trouble in some environment configurations. Could you please elaborate?
@marcoabreu We don't have a batch norm implemented for sparse arrays, so 'row_sparse' was there to test whether we can fall back correctly. The storage type inference logic for sparse ops shares some code paths with the dispatch inference logic for MKLDNN operators, which is actually what we want to test, so that's why it's now moved there.
I see, that makes sense. Where does it fall back to if MKLDNN is not available? My concern here is that this fallback actually seems to be failing (not generating the expected values) in some cases. This shouldn't happen, right? To me, it seems like we're now removing this problematic test although the problem might lie in whatever algorithm implementation we use in that fallback - the test would then actually be serving its purpose by revealing that problem.
@marcoabreu You are right: when row_sparse is not available, it should fall back to the default MXNet implementation, which may or may not be MKLDNN. However, for MKLDNN-enabled operators the fallback logic was changed, which introduced a bug tracked at #11448. Thus it makes sense to move the fallback check under mkl/test_mkldnn.py, and for each impacted operator there should be a similar unit test. @haojin2 is double checking whether the fallback values and logic are fixed and correct.
I think we should leave the tests as they are; there can be additional tests specific to the MKLDNN backend. Our users should not have to care which backend is selected, and the behaviour should always be identical - these high-level operator tests implicitly test the fallback behaviour. This means that however I call a certain operator, with row_sparse storage in this case, I would expect a proper result in any case. It's great that the issue is already tracked, but I'm heavily opposed to removing that case from our test suite. All operator tests have to be implementation-independent and should not differentiate or exclude certain cases. From my point of view, this flaky test is only resolved when the mkldnn implementation has been fixed and this test passes all the time. Removing it is not an option.
@marcoabreu I think we can merge this first since I've identified the root cause of the inconsistencies between sparse and dense matrices, and will be submitting a fix with corresponding tests soon. |
Well, the problem is not really solved; the disabling has just been narrowed down - instead of disabling the entire suite, this PR now disables a single sub-test. Right now the test is disabled, and we should leave it like that until it has been stabilized entirely. I'm happy to merge if we leave the test disabled (re-add the skip annotation) or once you have added the mentioned fixes.
@marcoabreu I just submitted a fix for that, and I need the fix for the flakiness of the existing test in this PR in that PR. |
I just edited my post, please check the last section
Excellent. Let's submit this PR with the test disabled and you'll take it from there in your PR. How does that sound? |
@marcoabreu The test for fallback correctness should not even appear in this file; this file is for operators under normal dense inputs, and tests related to sparse operators should go into test_sparse_operators.py. As this test can pass more than 10000 times under dense inputs, I would say the test on dense inputs is already stable.
Very good point about the sparse operator tests, I wasn't aware of them. Thank you! I will make a last pass and merge if everything is as expected. |
@marcoabreu Thanks for understanding! Please also take a look at my PR when you have time to do so, thanks! @roywei Thanks for your work on this complicated issue! |
* increase atol to 1e-2
* enable test_batchnorm_training
* remove row_sparse as it's tested in another test, increase mkldnn batchnorm test atol to 1e-2
Description
Fix issue #8044
tracked at: https://issues.apache.org/jira/browse/MXNET-620
Solution:
Bump up atol value from 1e-4 to 1e-2
The following 3 tests related to batchnorm training can all pass 10000 runs:
Finding 1:
Both fp32 with atol=1e-2 and fp64 with atol=1e-3 work; chose the former.
Tried a different numeric_eps=1e-3; it didn't work for all cases.
Finding 2:
The failures occur because the batchnorm gradient is slightly different from the gradient calculated numerically; the current implementation of batchnorm should be correct.
Finding 3:
Verified that the gradients of 3 different batchnorm implementations (CPU, GPU without cuDNN, GPU with cuDNN) are close to each other, but they all differ from the numeric gradient.
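Finding 3 is consistent with how finite differences behave: a central-difference estimate carries an O(eps^2) truncation error, so several correct implementations can agree with each other yet all sit a small, systematic distance from the numeric estimate. A self-contained illustration in plain Python (no MXNet; f = sin is just a smooth stand-in, and treating check_numeric_gradient as using a central difference is an assumption here):

```python
import math

# Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps).
# Taylor expansion shows its error is about eps**2 / 6 * |f'''(x)|.
def numeric_grad(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

x = 1.0
analytic = math.cos(x)                           # exact derivative of sin
numeric = numeric_grad(math.sin, x, eps=1e-2)    # eps matches numeric_eps=1e-2
err = abs(analytic - numeric)

# With eps=1e-2 the truncation error here is around 9e-6: small, but
# nonzero even though the "analytic gradient" is exactly right. Less
# smooth functions (or fp32 rounding) push this gap higher, which is
# why an overly tight atol makes such comparisons flaky.
print(err)
```

The point is not the exact number but the shape of the argument: the numeric gradient is itself an approximation, so a mismatch against it within the truncation-error budget does not indicate a wrong backward pass.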
Conclusion
See comments in #8044 for detailed results.