[cherry-pick] Fix the CE's bug when axis is specified and weight is provided #36647
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PR types
Bug fixes
PR changes
APIs
Describe
cherry pick from #36344
Background:
CrossEntropy loss function uses hard label, and when the weight is received at the same time, then if specify the axis other than -1, an error will be reported in the calculation process
Problem location:
gather_nd will be used when calculating the intermediate variable weight_gather, but when the coordinate value is not in the last dimension, gather_nd will make an error. When the input shape is as described in the background, it will cause the above problem
Solution:
When the axis is specified as a dimension other than -1, manually construct a correct permutation and pass it to the subsequent gather_nd function to use, so as to get the weight_gather with the correct shape
Other modifications:
Fixed some error judgment conditions related to axis, and add an test case in unittest