fixing var-seq-len rnn backward() operator #15278
Conversation
For perspective, this is the current output of my unit test, showing the magnitude of the difference, for one of the gradients, between what I get and what I expect:
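The raw numbers are not reproduced here, but the comparison the test performs can be sketched as follows (the arrays below are made-up stand-ins, not the PR's actual gradients):

```python
import numpy as np

# Hypothetical stand-ins for the cudnn gradient and the reference-net gradient.
grad_cudnn = np.array([0.1234567, -0.9876543, 0.5555555], dtype=np.float32)
grad_ref   = np.array([0.1234569, -0.9876540, 0.5555551], dtype=np.float32)

# Worst-case elementwise difference: the magnitude a failing
# gradient-comparison test typically reports.
max_abs_diff = np.max(np.abs(grad_cudnn - grad_ref))
rel_diff = max_abs_diff / np.max(np.abs(grad_ref))
```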
Just to keep the ticket updated: I have confirmed the following facts:
And I look at the resulting output and see:
And this does match up to the corresponding call to the backward functions, i.e.
And the same for cudnnRNNBackwardWeightsEx(). My suspicion now is that the reference net gradient is losing floating-point precision because it goes through extra reverse / concat / etc. operations. Going to consider another way of constructing the reference net for testing the gradient.
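The precision-loss hypothesis is plausible because float32 addition is not associative: routing the same values through extra reverse/concat stages can change the order in which a reduction accumulates them. A toy illustration of that effect (unrelated to the PR's actual code):

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.uniform(-1.0, 1.0, size=10000).astype(np.float32)

# Sum the same values in two different orders.  float32 rounding makes
# the two results agree only approximately, which is exactly the kind of
# "close but not close enough" discrepancy a reordered reference net
# can introduce.
s_fwd = np.float32(0.0)
for v in x:
    s_fwd += v
s_rev = np.float32(0.0)
for v in x[::-1]:
    s_rev += v

# Neither ordering is exact; compare both against a float64 reference.
s_exact = x.astype(np.float64).sum()
```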
Okay, I think the PR is good now. The problem was indeed what I speculated before: the cudnn backward pass was producing the correct gradient, the reference net was "close but not close enough" to it, and the discrepancy came from the reference net itself. I changed the reference net to be a dead-simple LSTM that processes each batch element one at a time, so that each time the LSTM can size itself appropriately to the current input. For the backward pass I set the gradient parameter to accumulate, so that I can compare against the LSTM using sequence_length with a batch. The unit test now tests the backward pass, and it is successful.
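The testing strategy described above can be sketched with a toy model in plain NumPy (hypothetical names; this is not MXNet's actual test code). The "net" here is a trivial one whose loss is w * sum(valid inputs), so the gradient w.r.t. w is just the sum of the valid timesteps; the point is the structure: compute the gradient per sample with each sequence sized to its true length, accumulate across the batch (the analogue of grad_req='add'), and compare against a masked batched computation (the analogue of use_sequence_length):

```python
import numpy as np

def per_sample_grad(batch, seq_lens):
    # Reference path: process each batch element one at a time, sized to
    # its own sequence length, and accumulate the weight gradient
    # across elements (the analogue of grad_req='add').
    grad_w = 0.0
    for x, n in zip(batch, seq_lens):
        grad_w += x[:n].sum()  # d(w * sum(x_valid)) / dw for this sample
    return grad_w

def batched_grad(batch, seq_lens):
    # Batched path: pad to max length and mask out invalid timesteps
    # (the analogue of running with use_sequence_length=True).
    mask = np.arange(batch.shape[1])[None, :] < np.asarray(seq_lens)[:, None]
    return (batch * mask).sum()

# Two sequences of lengths 3 and 2, padded to the max length 3.
batch = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 0.0]])
seq_lens = [3, 2]
```

With this setup both paths must agree exactly, since the toy model has no reordering of floating-point reductions.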
Looks like the test failed due to a tolerance issue. (I tried with the exact same random seed on my compute instance and the test passed, however.) Trying again with updated tolerance values of (Quick aside: how do people usually look at the test log files? It keeps crashing my browser before I can click on the "get raw text" link.)
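For reference, tolerance-based gradient comparisons of this kind generally use the criterion of numpy.testing.assert_allclose: pass iff |actual - desired| <= atol + rtol * |desired| elementwise. A minimal sketch (the rtol/atol values below are illustrative, not the PR's):

```python
import numpy as np

def grads_match(actual, desired, rtol=1e-3, atol=1e-5):
    # Same elementwise criterion numpy.testing.assert_allclose applies:
    # |actual - desired| <= atol + rtol * |desired|
    return bool(np.all(np.abs(actual - desired) <= atol + rtol * np.abs(desired)))

a = np.array([1.0000, 2.0000], dtype=np.float32)
b = np.array([1.0005, 2.0010], dtype=np.float32)
```

Whether a CI failure like the one above is a real bug or just an overly tight rtol comes down to how the observed differences compare against that threshold.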
Description
This PR fixes the problem mentioned here: #15268
NOTE: this is currently in a preliminary state. The backward() pass now works without crashing and produces a gradient.
HOWEVER, the unit test I added to confirm that the gradient we get matches what we expect is failing because the gradients do not match.
Due to the time-sensitive nature of the PR, I am creating it early in case more eyeballs can help.
@szha @roywei