[Bug] No grad in YourTTS speaker encoder #2348

Closed
Tomiinek opened this issue Feb 16, 2023 · 6 comments

Tomiinek commented Feb 16, 2023

Describe the bug

Hello guys (CC: @Edresson @WeberJulian), while going through the YourTTS code & paper, I noticed that you compute the inputs to the speaker encoder without gradients:

def forward(self, x, l2_norm=False):
    """Forward pass of the model.

    Args:
        x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
            to compute the spectrogram on-the-fly.
        l2_norm (bool): Whether to L2-normalize the outputs.

    Shapes:
        - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
    """
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=False):
            x.squeeze_(1)
            # if you torch spec compute it otherwise use the mel spec computed by the AP
            if self.use_torch_spec:
                x = self.torch_spec(x)

            if self.log_input:
                x = (x + 1e-6).log()
            x = self.instancenorm(x).unsqueeze(1)

    x = self.conv1(x)
    x = self.relu(x)
    x = self.bn1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = x.reshape(x.size()[0], -1, x.size()[-1])

    w = self.attention(x)

    if self.encoder_type == "SAP":
        x = torch.sum(x * w, dim=2)
    elif self.encoder_type == "ASP":
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
        x = torch.cat((mu, sg), 1)

    x = x.view(x.size()[0], -1)
    x = self.fc(x)

    if l2_norm:
        x = torch.nn.functional.normalize(x, p=2, dim=1)
    return x

I suspect that no gradients are propagated back through the speaker encoder, so the speaker consistency loss has no effect.
It looks like this happens:

  • forward gets x with grads
  • the spectrogram is extracted via torch_spec with no grads
  • the output of the speaker encoder has requires_grad=False (calling backward on it directly would raise an exception, since no graph was recorded for the gradient computation), but it is added to the total loss, which has requires_grad=True
  • so loss.backward() runs as usual, but the speaker encoder does not contribute to the gradients flowing to the generator at all
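The chain of events above can be reproduced in a toy training step. This is a minimal sketch, not the YourTTS code: `generator` and `speaker_encoder` are hypothetical one-layer stand-ins, `wav.abs()` stands in for the spectrogram front end, and the speaker consistency loss is assumed to be a negative cosine similarity between embeddings of generated and ground-truth audio. It compares the generator's gradient with and without the SCL term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: a one-layer "generator" and a frozen "speaker encoder".
generator = nn.Linear(16, 16)
speaker_encoder = nn.Linear(16, 8)
for p in speaker_encoder.parameters():
    p.requires_grad_(False)  # frozen while training the TTS model

text = torch.randn(4, 16)    # fixed toy inputs so the runs are comparable
target = torch.randn(4, 16)  # stands in for ground-truth audio


def embed(wav, use_no_grad):
    if use_no_grad:
        with torch.no_grad():  # mirrors the reported bug
            feats = wav.abs()  # stand-in for torch_spec + instancenorm
    else:
        feats = wav.abs()
    return speaker_encoder(feats)


def generator_grad(scl_weight, use_no_grad):
    generator.zero_grad()
    fake = generator(text)
    recon = F.l1_loss(fake, target)
    # Speaker consistency loss sketched as a negative cosine similarity (assumption)
    scl = -F.cosine_similarity(embed(fake, use_no_grad), embed(target, use_no_grad)).mean()
    (recon + scl_weight * scl).backward()  # works even when scl has no grad_fn
    return generator.weight.grad.clone()


g_recon_only = generator_grad(scl_weight=0.0, use_no_grad=True)
g_buggy = generator_grad(scl_weight=1.0, use_no_grad=True)
g_fixed = generator_grad(scl_weight=1.0, use_no_grad=False)

print(torch.equal(g_recon_only, g_buggy))  # True: the SCL term adds nothing
print(torch.equal(g_recon_only, g_fixed))  # False: the SCL gradient now reaches the generator
```

With the `no_grad` wrapper, the generator's gradient is bit-for-bit identical to the reconstruction-only gradient, which is exactly the "backward works but the encoder contributes nothing" symptom.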

Could you please check on that?

To Reproduce

import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

with torch.no_grad():
    c = a + b

d = c + 1
e = a + d
e.backward()

print(a.grad)  # tensor(1.)
print(b.grad)  # None: no gradient reaches b through c

d.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Expected behavior

No response

Logs

No response

Environment

Just reading the code and asking :)

Additional context

No response

Tomiinek added the "bug" label (Something isn't working) Feb 16, 2023

WeberJulian (Contributor) commented

Nice catch, that's concerning. Thanks for reporting it. We'll look into it more, but it looks like you are right.
This torch.no_grad doesn't affect the training of the speaker encoder itself, because none of its trainable parameters come before it. But in our case, using the original implementation for SCL, the TTS model's parameters sit before this torch.no_grad, so their gradients are cut off. I guess the slightly better SECS scores in the paper are explained by the extra training steps...
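A toy illustration of this asymmetry, with hypothetical one-layer modules standing in for the TTS model (`upstream`) and the speaker encoder (`encoder`); these are not the actual YourTTS classes. Parameters after the `no_grad` block still receive gradients, while anything before it does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

upstream = nn.Linear(8, 8)  # hypothetical TTS layer that feeds the encoder
encoder = nn.Linear(8, 4)   # hypothetical trainable encoder layer after the no_grad block


def encode(x):
    with torch.no_grad():
        feats = x * 2.0     # stand-in for torch_spec + instancenorm
    return encoder(feats)


# Case 1: training the speaker encoder itself; its input is plain data.
encode(torch.randn(3, 8)).sum().backward()
print(encoder.weight.grad is not None)  # True: encoder parameters still learn

# Case 2: a TTS layer sits upstream, as with the speaker consistency loss.
encoder.zero_grad()
encode(upstream(torch.randn(3, 8))).sum().backward()
print(upstream.weight.grad is None)  # True: the no_grad block cuts the path to it
```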

Edresson (Contributor) commented

Nice catch. It is indeed an issue. I will submit a PR to fix it.
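A minimal sketch of what such a fix could look like, assuming it amounts to dropping the `torch.no_grad()` wrapper while keeping mixed precision disabled for feature extraction. This is not the merged PR: `TinySpeakerEncoder`, its STFT front end, and the mean-pooling head are hypothetical stand-ins for the ResNet encoder quoted above. It also uses the out-of-place `x.squeeze(1)` instead of `squeeze_`, so the caller's tensor is not mutated.

```python
import torch
import torch.nn as nn


class TinySpeakerEncoder(nn.Module):
    """Hypothetical stand-in for the ResNet speaker encoder, keeping only the parts relevant to the bug."""

    def __init__(self, n_fft=64, hop_length=16, emb_dim=8):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.register_buffer("window", torch.hann_window(n_fft))
        n_bins = n_fft // 2 + 1
        self.instancenorm = nn.InstanceNorm1d(n_bins)
        self.fc = nn.Linear(n_bins, emb_dim)

    def torch_spec(self, wav):
        # Differentiable magnitude spectrogram: (N, T) -> (N, n_bins, frames)
        spec = torch.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length,
                          window=self.window, return_complex=True)
        return spec.abs()

    def forward(self, x, l2_norm=False):
        # No torch.no_grad() here, so gradients can flow back to x.
        # Autocast stays disabled so the spectrogram is computed in full precision.
        with torch.autocast(device_type=x.device.type, enabled=False):
            x = x.squeeze(1)  # out-of-place; does not mutate the caller's tensor
            x = self.torch_spec(x)
            x = (x + 1e-6).log()
            x = self.instancenorm(x)
        x = x.mean(dim=2)  # average over frames (stand-in for attentive pooling)
        x = self.fc(x)
        if l2_norm:
            x = torch.nn.functional.normalize(x, p=2, dim=1)
        return x


torch.manual_seed(0)
encoder = TinySpeakerEncoder()
for p in encoder.parameters():
    p.requires_grad_(False)  # frozen, as when it is only used for the consistency loss

wav = torch.randn(2, 1, 400, requires_grad=True)  # stands in for generator output
emb = encoder(wav, l2_norm=True)
print(emb.requires_grad)  # True
emb.sum().backward()
print(wav.grad is not None)  # True
```

Freezing the encoder's parameters is still compatible with the loss: `requires_grad=False` on weights only stops them from being updated, while gradients keep flowing through them to the waveform.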

Tomiinek (Author) commented Feb 23, 2023

Thank you @Edresson

Are you also planning to retrain the models and update the YourTTS paper at least on arxiv? 😇

Edresson (Contributor) commented

> Thank you @Edresson
>
> Are you also planning to retrain the models and update the YourTTS paper at least on arxiv? 😇

I don't think it is worth it, because if we did, we would need to recompute the MOS and Sim-MOS. I'm thinking about updating the preprint and removing the Speaker Consistency Loss from the methodology. Given that the Speaker Consistency Loss had no effect on the results, the Speaker Consistency Loss experiments are equivalent to training the model for 50k more steps. In addition, I will try to correct this in the published ICML paper as well. Fortunately, it is a minor issue and the reported results are not affected (only the method description is wrong).


stale bot commented Mar 28, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look at our discussion channels.

stale bot added the "wontfix" label (This will not be worked on but feel free to help.) Mar 28, 2023
stale bot closed this as completed Apr 5, 2023
erogol removed the "wontfix" label Apr 5, 2023

Edresson (Contributor) commented

@Tomiinek Thanks so much for finding the bug and reporting it. I talked with all the authors, and the final decision was to add an erratum to the YourTTS GitHub repository and to the last page of the preprint. It is done :).
