This should be 1 - f, according to the paper. The confusion arose around the effect of the "forget" gate: in the LSTM and GRU papers, information is passed through when f is high, but in the MGU paper it is the opposite. Variable f from the MGU paper is effectively 1 - f in Flax (it is the portion that contributes to the short-term response, or n in Flax-speak). From the paper:
In MGU, the forget gate f_t is first generated, and the element-wise product between 1 - f_t and h_{t-1} becomes part of the new hidden state h_t. The portion of h_{t-1} that is "forgotten" (f_t ⊙ h_{t-1}) is combined with x_t to produce h_bar_t, the short-term response. A portion of h_bar_t (determined again by f_t) forms the second part of h_t.
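The paper's description can be written as a single update step. Below is a minimal sketch in plain Python for one hidden unit (scalar x and h) so the gating is easy to follow. It assumes the MGU paper's equations f_t = sigmoid(W_f [h_{t-1}, x_t] + b_f), h_bar_t = tanh(W_h [f_t ⊙ h_{t-1}, x_t] + b_h) and h_t = (1 - f_t) ⊙ h_{t-1} + f_t ⊙ h_bar_t. The function name `mgu_step_paper` and the weight names in `w` are hypothetical and are not taken from Flax.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def mgu_step_paper(x: float, h_prev: float, w: dict) -> tuple:
    """One MGU step written verbatim from the paper (single hidden unit).

    `w` holds hypothetical scalar weights:
      wf_x, wf_h, bf for the forget gate; wh_x, wh_h, bh for the candidate.
    Returns (h_new, f, h_bar).
    """
    # Forget gate f_t computed from the input and the previous state.
    f = sigmoid(w["wf_x"] * x + w["wf_h"] * h_prev + w["bf"])
    # The "forgotten" portion f * h_prev is combined with x: short-term response.
    h_bar = math.tanh(w["wh_x"] * x + w["wh_h"] * (f * h_prev) + w["bh"])
    # (1 - f) * h_prev is kept; f * h_bar forms the second part of h_t.
    h_new = (1.0 - f) * h_prev + f * h_bar
    return h_new, f, h_bar


if __name__ == "__main__":
    # Placeholder weights; only the gate bias differs between the two runs.
    base = dict(wf_x=0.0, wf_h=0.0, wh_x=1.0, wh_h=1.0, bh=0.0)
    for bf in (-20.0, 20.0):
        h_new, f, h_bar = mgu_step_paper(0.7, 0.4, {**base, "bf": bf})
        print(f"bf={bf}: f={f:.3g}  h_bar={h_bar:.4f}  h_new={h_new:.4f}")
```

With a strongly negative gate bias (f near 0) the step copies h_prev; with a strongly positive one (f near 1) it returns the candidate, which in that case sees all of h_prev.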
I now think the Flax code is fine and the problem is with the paper. I came across this blog post, which summarises it well: https://zjusticy.github.io/blog/A-Modified-Minimal-Gated-Unit-(MGU)-Structure. I have run my own experiments and can confirm that a verbatim implementation of the paper leads to a conflict between h_t (the new state) and h_bar_t (the candidate state). I will run a couple more experiments before closing the issue.
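A minimal sketch of why the single tied gate is easy to misread, for one hidden unit: both variants below use the same gate f to scale h_prev inside the candidate, but `step_paper` follows the paper's update, while `step_gru_style` is a hypothetical alternative whose update keeps the old state when f is high (the GRU-style n/h convention). The gate value is passed in directly rather than computed, and `wx`, `wh`, `b` are placeholder weights. Which variant matches line 725 of recurrent.py, or the modification proposed in the linked blog post, is not claimed here.

```python
import math


def candidate(x: float, h_prev: float, f: float,
              wx: float = 1.0, wh: float = 1.0, b: float = 0.0) -> float:
    # Both variants scale h_prev by the same gate f inside the candidate.
    return math.tanh(wx * x + wh * (f * h_prev) + b)


def step_paper(x: float, h_prev: float, f: float) -> float:
    # Paper's update: a high f admits the candidate.
    return (1.0 - f) * h_prev + f * candidate(x, h_prev, f)


def step_gru_style(x: float, h_prev: float, f: float) -> float:
    # Hypothetical alternative: a high f keeps the old state instead.
    return f * h_prev + (1.0 - f) * candidate(x, h_prev, f)


if __name__ == "__main__":
    x, h_prev = 0.5, 0.8
    for f in (0.0, 1.0):
        print(f"f={f}: paper={step_paper(x, h_prev, f):.4f}  "
              f"gru_style={step_gru_style(x, h_prev, f):.4f}")
```

At f = 0 the paper's update copies h_prev, while the alternative returns tanh(wx * x + b), a full reset; at f = 1 the roles swap, and the paper's update becomes a plain tanh RNN step.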
Referenced code: flax/flax/linen/recurrent.py, line 725 in d59132d