Skip to content

Commit

Permalink
[Bugfix][TTS] wrong order of returned tuple for general_collate_fn. (#…
Browse files Browse the repository at this point in the history
…4388)

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
  • Loading branch information
XuesongYang authored Jun 17, 2022
1 parent 0322b15 commit 317739f
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def __init__(
log_mel_folder (Optional[Union[Path, str]]): The folder that contains or will contain log mel spectrograms.
align_prior_matrix_folder (Optional[Union[Path, str]]): The folder that contains or will contain align prior matrices.
pitch_folder (Optional[Union[Path, str]]): The folder that contains or will contain pitch.
voiced_mask_folder (Optional[Union[Path, str]]): The folder that contains or will contain voiced mask of the pitch
p_voiced_folder (Optional[Union[Path, str]]): The folder that contains or will contain p_voiced(probability) of the pitch
voiced_mask_folder (Optional[Union[Path, str]]): The folder that contains or will contain voiced mask of the pitch
p_voiced_folder (Optional[Union[Path, str]]): The folder that contains or will contain p_voiced(probability) of the pitch
energy_folder (Optional[Union[Path, str]]): The folder that contains or will contain energy.
durs_file (Optional[str]): String path to pickled durations location.
durs_type (Optional[str]): Type of durations. Currently supported only "aligner-based".
Expand Down Expand Up @@ -581,9 +581,9 @@ def general_collate_fn(self, batch):
pitches_lengths,
energies,
energies_lengths,
_,
voiced_masks,
p_voiceds,
_,
) = zip(*batch)

max_audio_len = max(audio_lengths).item()
Expand Down Expand Up @@ -633,7 +633,7 @@ def general_collate_fn(self, batch):
energy_length,
speaker_id,
voiced_mask,
p_voiceds,
p_voiced,
) = sample_tuple

audio = general_padding(audio, audio_len.item(), max_audio_len)
Expand All @@ -644,22 +644,27 @@ def general_collate_fn(self, batch):

if LogMel in self.sup_data_types_set:
log_mels.append(general_padding(log_mel, log_mel_len, max_log_mel_len, pad_value=log_mel_pad))

if Durations in self.sup_data_types_set:
durations_list.append(general_padding(durations, len(durations), max_durations_len))

if AlignPriorMatrix in self.sup_data_types_set:
align_prior_matrices[
i, : align_prior_matrix.shape[0], : align_prior_matrix.shape[1]
] = align_prior_matrix

if Pitch in self.sup_data_types_set:
pitches.append(general_padding(pitch, pitch_length.item(), max_pitches_len))

if Voiced_mask in self.sup_data_types_set:
voiced_masks.append(general_padding(voiced_mask, pitch_length.item(), max_pitches_len))

if P_voiced in self.sup_data_types_set:
p_voiceds.append(general_padding(voiced_mask, pitch_length.item(), max_pitches_len))
p_voiceds.append(general_padding(p_voiced, pitch_length.item(), max_pitches_len))

if Energy in self.sup_data_types_set:
energies.append(general_padding(energy, energy_length.item(), max_energies_len))

if SpeakerID in self.sup_data_types_set:
speaker_ids.append(speaker_id)

Expand Down

0 comments on commit 317739f

Please sign in to comment.