Skip to content

Commit

Permalink
[TTS] Change audio codec token type to TokenIndex (#7356)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <rlangman@nvidia.com>
  • Loading branch information
rlangman authored and yaoyu-33 committed Sep 5, 2023
1 parent f537f39 commit e26c41b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 36 deletions.
11 changes: 5 additions & 6 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import itertools
import random
from pathlib import Path
from typing import List, Tuple

Expand All @@ -34,7 +33,7 @@
from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers
from nemo.core import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, TokenIndex
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
from nemo.utils import logging, model_utils
Expand Down Expand Up @@ -168,7 +167,7 @@ def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[t
"encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), Index())},
output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex())},
)
def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
"""Quantize the continuous encoded representation into a discrete
Expand All @@ -192,7 +191,7 @@ def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Te

@typecheck(
input_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),},
Expand Down Expand Up @@ -221,7 +220,7 @@ def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Te
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
)
Expand All @@ -244,7 +243,7 @@ def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Te

@typecheck(
input_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
Expand Down
30 changes: 0 additions & 30 deletions nemo/collections/tts/modules/vector_quantization.py

This file was deleted.

0 comments on commit e26c41b

Please sign in to comment.