Replies: 5 comments (2 from erogol, 3 from justin_tian; the reply bodies were not captured in this archive)
>>> justin_tian
[September 10, 2020, 3:49pm]
I'm trying to implement Graves (GMM) attention based on the Mozilla TTS
repo. Here is a link with brief discussions about the implementation by
the repo maintainer. The code below is my implementation, adapted to fit
flowtron (https://github.com/NVIDIA/flowtron).
The author of flowtron made some changes to Tacotron: all encoded mel
frames are fed at once, so the dimensions of queries, keys, and values
are
queries: T_mel x B x attn_hidden_dim
keys, values: T_text x B x text_embedding_dim
When I train it, it just doesn't work well. Only the first frame's
alignment is close to the maximum value, and the attention scores for the
other frames are really low. Also, convergence is slow compared to an
additive attention. Can someone help me figure out which part of the code
needs a fix (possibly the mu_t part)?
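
For comparison (this is just my own sketch, not code from Mozilla TTS or flowtron), the original Graves (2013) window runs one decoder step at a time and updates the window position recurrently, kappa_t = kappa_(t-1) + exp(k_hat_t), with an exponential kernel over the text positions. My module below instead predicts all steps at once, so it replaces the recurrence with softplus + cumsum over the mel axis and uses a sigmoid-based kernel. The helper name graves_window_step is made up for the sketch:

import torch

def graves_window_step(params_3k, kappa_prev, memory):
    """One step of the original Graves (2013) window (hypothetical helper).

    params_3k:  (B, 3K) raw network outputs for this decoder step
    kappa_prev: (B, K) window positions from the previous step
    memory:     (B, T_text, D) encoder outputs
    """
    B, three_k = params_3k.shape
    K = three_k // 3
    a_hat, b_hat, k_hat = params_3k.view(B, 3, K).unbind(dim=1)
    alpha = torch.exp(a_hat)               # mixture weights
    beta = torch.exp(b_hat)                # (inverse) window widths
    kappa = kappa_prev + torch.exp(k_hat)  # monotonic window positions (the mu_t role)
    u = torch.arange(memory.size(1), device=memory.device, dtype=memory.dtype)
    # phi(t, u) = sum_k alpha_k * exp(-beta_k * (kappa_k - u)^2)
    phi = (alpha.unsqueeze(-1)
           * torch.exp(-beta.unsqueeze(-1) * (kappa.unsqueeze(-1) - u) ** 2)).sum(dim=1)
    context = torch.bmm(phi.unsqueeze(1), memory).squeeze(1)  # (B, D)
    return context, phi, kappa

One thing I keep in mind when comparing the two: softplus gives about 0.69 at zero input while exp gives 1, so the cumsum-based mu_t can advance at a different rate per step than the original recurrence.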
import torch
import torch.nn as nn

# LinearNorm is flowtron's xavier-initialized linear layer and is assumed to be
# importable in this file.

class GravesAttention(torch.nn.Module):
    def __init__(self, n_mel_channels=80, n_speaker_dim=128,
                 n_text_channels=512, n_att_channels=256, K=4):
        super(GravesAttention, self).__init__()
        # K is the number of Gaussian components
        self.K = K
        self.mask_value = 1e-8
        self.eps = 1e-5
        self.J = None
        self.N_a = nn.Sequential(
            nn.Linear(n_mel_channels, n_mel_channels, bias=True),
            nn.ReLU(),
            nn.Linear(n_mel_channels, 3 * K, bias=True)
        )
        self.key = LinearNorm(n_text_channels + n_speaker_dim,
                              n_att_channels, bias=False, w_init_gain='tanh')
        self.value = LinearNorm(n_text_channels + n_speaker_dim,
                                n_att_channels, bias=False,
                                w_init_gain='tanh')
        self.init_layers()

    def init_layers(self):
        torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K):(3 * self.K)], 1.)  # bias mean
        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2 * self.K)], 10)  # bias std

    def init_states(self, inputs):
        if self.J is None or inputs.shape[0] + 1 > self.J.shape[-1]:
            self.J = torch.arange(0, inputs.shape[0] + 2.0).to(inputs.device) + 0.5

    def forward(self, queries, keys, values, mask=None, attn=None):
        self.init_states(keys)  # initialize self.J
        if attn is None:
            keys = self.key(keys).transpose(0, 1)  # B x in_lens x n_attn_channels
            values = self.value(values) if hasattr(self, 'value') else values
            values = values.transpose(0, 1)  # B x in_lens x n_attn_channels
            gbk_t = self.N_a(queries).transpose(0, 1)  # B x T x 3K
            gbk_t = gbk_t.view(gbk_t.size(0), gbk_t.size(1), -1, self.K)
            # each B x T x K
            g_t = gbk_t[:, :, 0, :]
            b_t = gbk_t[:, :, 1, :]
            k_t = gbk_t[:, :, 2, :]
            g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training)
            sig_t = torch.nn.functional.softplus(b_t) + self.eps
            k_t = torch.nn.functional.softplus(k_t)
            mu_t = torch.cumsum(k_t, dim=1)  # mu_t = mu_(t-1) + k_t, mu_0 = 0
            g_t = torch.softmax(g_t, dim=-1) + self.eps
            j = self.J[:values.size(1) + 1]
            phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
                (mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
            alpha_t = torch.sum(phi_t, 2)  # sum over the K mixture components
            alpha_t = alpha_t[:, :, 1:] - alpha_t[:, :, :-1]
            alpha_t[alpha_t == 0] = 1e-8
            if mask is not None:
                alpha_t.data.masked_fill_(mask.transpose(1, 2), self.mask_value)
        else:
            # flowtron can pass precomputed attention weights in `attn`
            alpha_t = attn
            values = self.value(values)
            values = values.transpose(0, 1)
        print('with_dropout flows2 max, min in alpha_t {} {}'.format(
            torch.max(alpha_t), torch.min(alpha_t)))
        output = torch.bmm(alpha_t, values)
        output = output.transpose(1, 2)
        return output, alpha_t
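
And a throwaway shape check (again my own sketch, not flowtron code) for the tensor layout above. LinearNorm here is a stand-in that mimics flowtron's xavier-initialized linear layer, and the snippet assumes it runs in the same file as the class so the name resolves:

import torch
import torch.nn as nn

class LinearNorm(nn.Module):
    """Stand-in for flowtron's LinearNorm (nn.Linear with xavier init)."""
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
        nn.init.xavier_uniform_(self.linear_layer.weight,
                                gain=nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)

T_mel, T_text, B = 100, 40, 2
attn_layer = GravesAttention()                     # defaults: 80 mel channels, 512 + 128 text/speaker dims
queries = torch.randn(T_mel, B, 80)                # T_mel x B x n_mel_channels
keys = torch.randn(T_text, B, 512 + 128)           # T_text x B x (n_text_channels + n_speaker_dim)
output, alpha_t = attn_layer(queries, keys, keys)  # values share the key tensor here
print(output.shape)   # expected: torch.Size([2, 256, 100])  -> B x n_att_channels x T_mel
print(alpha_t.shape)  # expected: torch.Size([2, 100, 40])   -> B x T_mel x T_text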
[This is an archived TTS discussion thread from discourse.mozilla.org/t/need-help-with-implementing-graves-attention-layer]