Merge pull request #4 from lucasnewman/backtranslation
Add trainers for the pretraining and backtranslation tasks
lucidrains authored Aug 6, 2023
2 parents ae9cad1 + 26a67ec commit e46d583
Showing 3 changed files with 698 additions and 27 deletions.
7 changes: 6 additions & 1 deletion spear_tts_pytorch/__init__.py
@@ -1,4 +1,9 @@
from spear_tts_pytorch.spear_tts_pytorch import (
TextToSemantic,
SpeechSpeechPretrainWrapper
SpeechSpeechPretrainWrapper,
SemanticToTextWrapper,
)
from spear_tts_pytorch.trainer import (
SpeechSpeechPretrainer,
SemanticToTextTrainer
)
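
For reference, everything exported at package level after this change (the two trainers come from the new spear_tts_pytorch/trainer.py, whose diff did not load on this page):

from spear_tts_pytorch import (
    TextToSemantic,
    SpeechSpeechPretrainWrapper,
    SemanticToTextWrapper,
    SpeechSpeechPretrainer,
    SemanticToTextTrainer
)

Per the commit message, each wrapper computes the loss for its task, while the matching trainer runs the training loop around it.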
169 changes: 143 additions & 26 deletions spear_tts_pytorch/spear_tts_pytorch.py
@@ -1,7 +1,9 @@
from contextlib import contextmanager
import math
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor, nn, einsum
from torch.nn import Module, ModuleList

@@ -28,6 +30,10 @@ def default(val, d):
def empty(t: Tensor):
return t.numel() == 0

@contextmanager
def null_context():
yield

def set_eos_id(t: Tensor, eos_id: int, pad_id: int):
eos_indices = ((t == pad_id).cumsum(dim = -1) == 0).sum(dim = -1, keepdim = True).long()

@@ -38,6 +44,10 @@ def set_eos_id(t: Tensor, eos_id: int, pad_id: int):
t[batch_range, eos_indices] = eos_id
return t

def batch_unique_consecutive(t, pad_value = 0.):
unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)]
return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value)
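
For reference, batch_unique_consecutive collapses runs of repeated tokens independently per batch row, then right-pads the rows back to a common length. A small illustration with arbitrary values:

t = torch.tensor([[1, 1, 2, 2, 3],
                  [4, 4, 4, 5, 5]])

batch_unique_consecutive(t, pad_value = 0)
# tensor([[1, 2, 3],
#         [4, 5, 0]])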

# sampling helpers

def eval_decorator(fn):
@@ -294,6 +304,22 @@ def forward(

return self.final_norm(x)

def model_forward_with_context(
*,
fn,
args,
freeze,
):
encoding_context = null_context if not freeze else torch.no_grad

with encoding_context():
enc = fn(*args)

if freeze:
enc.detach_()

return enc
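
model_forward_with_context runs an encoder either normally or under torch.no_grad, detaching the output when frozen, so a pretrained source encoder can stay fixed while the decoder trains. A minimal sketch of the contract, with a stand-in encoder and the module's own imports:

# with freeze = True the forward pass runs without autograd and the output is
# detached in-place, so no gradients can reach the encoder's parameters
encoder = nn.Linear(4, 4)
x = torch.randn(2, 4)

enc = model_forward_with_context(fn = encoder, args = (x,), freeze = True)
assert not enc.requires_grad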

# class

SpeechOrTextLiteral = Union[
@@ -322,11 +348,13 @@ def __init__(
dim_head = 64,
heads = 8,
attn_dropout = 0.,
ff_mult = 4,
ff_dropout = 0.,
semantic_pad_id = -1,
text_pad_id = 0,
autoset_semantic_eos_id = True,
autoset_text_eos_id = True,
freeze_encoder = False
):
super().__init__()
self.dim = dim
@@ -405,6 +433,7 @@ def __init__(
heads = heads,
depth = source_depth,
attn_dropout = attn_dropout,
ff_mult = ff_mult,
ff_dropout = ff_dropout,
causal = False
)
@@ -415,16 +444,22 @@ def __init__(
heads = heads,
depth = source_depth,
attn_dropout = attn_dropout,
ff_mult = ff_mult,
ff_dropout = ff_dropout,
causal = True,
cross_attend = True
)

self.freeze_encoder = freeze_encoder

def load(self, path, strict = True):
# return pkg so that, if this function is called from within a Trainer,
# the trainer can also access the package loaded from the checkpoint
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')
self.load_state_dict(pkg, strict = strict)
self.load_state_dict(pkg['model'], strict = strict)
return pkg
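
Note the changed checkpoint convention: load now expects a dict with the state dict under a 'model' key and returns the whole package. A sketch of the assumed save side (the optimizer entry is a hypothetical example of extra state a trainer might store):

optimizer = torch.optim.Adam(model.parameters())

pkg = dict(
    model = model.state_dict(),
    optim = optimizer.state_dict()  # hypothetical extra trainer state
)
torch.save(pkg, './checkpoint.pt')

pkg = model.load('./checkpoint.pt')  # weights restored, full pkg returned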

@property
def device(self):
@@ -443,7 +478,9 @@ def generate(
filter_logits_fn = top_k,
filter_thres = 0.9,
source_mask: Optional[Tensor] = None,
max_length = 2048
max_length = 2048,
beam_search_decode = False,
beam_size = 4,
):
if is_bearable(source, List[str]):
assert exists(self.tokenizer_encode)
@@ -474,7 +511,7 @@ def generate(

if not exists(source_mask) and source.dtype == torch.long:
source_mask = source != source_pad_id

# source embedding

source_emb = source_token_emb(source)
@@ -488,35 +525,73 @@

# loop to decode

for _ in tqdm(range(max_length)):
target_emb = target_token_emb(target)
target_emb = torch.cat((start_token, target_emb), dim = 1)
if not beam_search_decode:
for _ in tqdm(range(max_length)):
target_emb = target_token_emb(target)
target_emb = torch.cat((start_token, target_emb), dim = 1)

# target attention
# target attention

target_emb = self.target_transformer(target_emb, context = source_emb, context_mask = source_mask)
target_emb = self.target_transformer(target_emb, context = source_emb, context_mask = source_mask)

# decoder logits
# decoder logits

logits = target_to_logit(target_emb)
logits = target_to_logit(target_emb)

logits = logits[:, -1]
logits = filter_logits_fn(logits, thres = filter_thres)
logits = logits[:, -1]

logits = filter_logits_fn(logits, thres = filter_thres)

sampled = gumbel_sample(logits, temperature = temperature)
target, _ = pack((target, sampled), 'b *')
sampled = gumbel_sample(logits, temperature = temperature)
target, _ = pack((target, sampled), 'b *')

if not self.autoset_eos_id[target_type]:
continue
if not self.autoset_eos_id[target_type]:
continue

is_eos = target == target_eos_id
is_eos = target == target_eos_id

if not is_eos.any(dim = -1).all():
continue
if not is_eos.any(dim = -1).all():
continue

mask = is_eos.cumsum(dim = -1) == 0
target = target.masked_fill(~mask, target_pad_id)
break
mask = is_eos.cumsum(dim = -1) == 0
target = target.masked_fill(~mask, target_pad_id)
break
else:
beam = [(target, 0.0)]

for _ in tqdm(range(max_length)):
all_candidates = []

for sentence, sentence_prob in beam:
target_emb = target_token_emb(sentence)
target_emb = torch.cat((start_token, target_emb), dim = 1)

# target attention

target_emb = self.target_transformer(target_emb, context = source_emb, context_mask = source_mask)

# decoder logits

logits = target_to_logit(target_emb)
logits = logits[:, -1]

log_probs = torch.log_softmax(logits / max(temperature, 1e-10), dim = -1)
topk_log_probs, topk_ids = log_probs.topk(beam_size, dim = -1)

for i in range(beam_size):
candidate = torch.cat([sentence, topk_ids[..., i:i + 1]], dim = -1)
candidate_prob = sentence_prob + topk_log_probs[..., i]
all_candidates.append((candidate, candidate_prob))

ordered = sorted(all_candidates, key = lambda tup: tup[1], reverse = True)
beam = ordered[:beam_size]

# check if we've hit eos for all sequences
all_eos = all([((sentence == target_eos_id).any(dim = -1)).all() for sentence, _ in beam])
if all_eos:
break

target = beam[0][0]

return target
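
The new beam search path scores each candidate prefix by accumulated log-probability and keeps the beam_size best at every step, instead of sampling one filtered token at a time as the default path does. A hypothetical call, assuming generate takes the same source_type / target_type keywords as forward (they are not visible in this hunk):

text_ids = model.generate(
    semantic_ids,               # assumed: a LongTensor of semantic token ids
    source_type = 'speech',
    target_type = 'text',
    beam_search_decode = True,  # take the beam path instead of filtered sampling
    beam_size = 4
)

One caveat: the candidate ordering sorts tuples whose scores are per-batch tensors, and a tensor with more than one element has no single truth value, so as written the beam path appears to assume a batch size of 1.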

@@ -529,6 +604,7 @@ def forward(
source_type: SpeechOrTextLiteral,
target_type: SpeechOrTextLiteral,
source_mask: Optional[Tensor] = None,
target_mask: Optional[Tensor] = None,
return_loss = False
):
if is_bearable(source, List[str]):
@@ -564,12 +640,18 @@ def forward(
target_eos_id = self.eos_id[target_type]
target = set_eos_id(target, target_eos_id, pad_id = target_pad_id)

# if source mask is not passed in
# if source/target mask is not passed in
# automatically derive by the padding id of the modality

if not exists(source_mask) and source.dtype == torch.long:
source_mask = source != source_pad_id

if not exists(target_mask) and target.dtype == torch.long:
target_mask = target != target_pad_id

# attend to bos: extend the mask by one leading True for the prepended start token
target_mask = F.pad(target_mask, (1, 0), value = True)
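
For intuition, with text_pad_id = 0 a target row such as [5, 6, 3, 0, 0] masks out only the trailing padding, and the extra leading True covers the start token prepended to the embeddings:

target = torch.tensor([[5, 6, 3, 0, 0]])
target_mask = target != 0
# tensor([[ True,  True,  True, False, False]])
target_mask = F.pad(target_mask, (1, 0), value = True)
# tensor([[ True,  True,  True,  True, False, False]])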

# embedding

source_emb = source_token_emb(source)
@@ -581,11 +663,15 @@

# source attention

source_emb = self.source_transformer(source_emb, mask = source_mask)
source_emb = model_forward_with_context(
fn = self.source_transformer,
args = (source_emb, source_mask),
freeze = self.freeze_encoder
)

# target attention

target_emb = self.target_transformer(target_emb, context = source_emb, context_mask = source_mask)
target_emb = self.target_transformer(target_emb, mask = target_mask, context = source_emb, context_mask = source_mask)

# decoder logits

@@ -636,7 +722,6 @@ def __init__(

self.model = model
self.wav2vec = default(wav2vec, model.wav2vec)
assert exists(self.wav2vec)

self.deletion_prob = deletion_prob
self.reconstruct_seq = reconstruct_seq # whether to reconstruct the entire sequence, or just output the deleted ones in order
@@ -648,6 +733,8 @@ def forward(
is_raw_audio = x.dtype == torch.float

if is_raw_audio:
assert exists(self.wav2vec)

with torch.no_grad():
self.wav2vec.eval()
x = self.wav2vec(x, flatten = False)
@@ -673,3 +760,33 @@ def forward(
)

return loss
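
With the assert relocated into the raw-audio branch, a wav2vec model is only required when the wrapper is actually handed float audio; precomputed semantic token ids pass straight through. A minimal sketch of both input modes (shapes and constructor defaults assumed):

pretrainer = SpeechSpeechPretrainWrapper(model = model)  # model: a TextToSemantic

raw_audio = torch.randn(2, 16000)                # float input: encoded by wav2vec
loss = pretrainer(raw_audio)

semantic_ids = torch.randint(0, 500, (2, 1024))  # long input: used as-is
loss = pretrainer(semantic_ids)
loss.backward()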

# wrapper for backtranslation task

class SemanticToTextWrapper(nn.Module):
@beartype
def __init__(
self,
model: TextToSemantic
):
super().__init__()

self.model = model

def forward(
self,
semantic_token_ids,
grapheme_token_ids,
):
source = semantic_token_ids
target = grapheme_token_ids

loss = self.model(
source, target,
source_type = 'speech',
target_type = 'text',
return_loss = True
)

return loss
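
SemanticToTextWrapper simply runs the underlying TextToSemantic model in the reverse, speech-to-text direction to get the backtranslation loss. A minimal usage sketch with illustrative shapes:

backtranslator = SemanticToTextWrapper(model = model)  # model: a TextToSemantic

semantic_token_ids = torch.randint(0, 500, (2, 1024))
grapheme_token_ids = torch.randint(0, 256, (2, 256))

loss = backtranslator(semantic_token_ids, grapheme_token_ids)
loss.backward()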

549 changes: 549 additions & 0 deletions spear_tts_pytorch/trainer.py (diff did not load; counts inferred from the totals above)
